From 89c65752dc633cd80c30133563ac1fef464c74d8 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Tue, 29 Oct 2019 15:46:05 -0700 Subject: [PATCH] Update the signature of vector script functions. (#48653) Previously the functions accepted a doc values reference, whereas they now accept the name of the vector field. Here's an example of how a vector function was called before and after the change. ``` Before: cosineSimilarity(params.query_vector, doc['field']) After: cosineSimilarity(params.query_vector, 'field') ``` This seems more intuitive, since we don't allow direct access to vector doc values and the meaning of `doc['field']` is unclear. The PR makes the following changes (broken into distinct commits): * Add new function signatures of the form `function(params.query_vector, 'field')` and deprecate the old ones. Because Painless doesn't allow two methods with the same name and number of arguments, we allow a generic `Object` to be passed in to the function and decide on the behavior through an `instanceof` check. * Refactor the class bindings so that the document field is passed to the constructor instead of the instance method. This allows us to avoid retrieving the vector doc values on every function invocation, which gives a tiny speed-up in benchmarks. Note that this PR adds new signatures for the sparse vector functions too, even though sparse vectors are deprecated. It seemed simplest to understand (for both us and users) to keep everything symmetric between dense and sparse vectors. 
--- .../index.asciidoc | 8 +- docs/reference/migration/migrate_7_6.asciidoc | 7 + .../vectors/vector-functions.asciidoc | 18 +- .../org/elasticsearch/script/ScoreScript.java | 2 +- .../test/vectors/10_dense_vector_basic.yml | 29 ++- .../test/vectors/15_dense_vector_l1l2.yml | 4 +- .../vectors/20_dense_vector_special_cases.yml | 14 +- .../test/vectors/30_sparse_vector_basic.yml | 28 ++- .../test/vectors/35_sparse_vector_l1l2.yml | 4 +- .../40_sparse_vector_special_cases.yml | 20 +- .../xpack/vectors/query/ScoreScriptUtils.java | 148 ++++++++---- .../xpack/vectors/query/whitelist.txt | 16 +- .../vectors/query/ScoreScriptUtilsTests.java | 223 ++++++++++++++---- 13 files changed, 377 insertions(+), 144 deletions(-) diff --git a/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc b/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc index d355a495e06..ea272a3e392 100644 --- a/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc +++ b/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc @@ -10,8 +10,8 @@ The following specialized API is available in the Score context. ==== Static Methods The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values. 
-* double cosineSimilarity(List *, VectorScriptDocValues.DenseVectorScriptDocValues) -* double cosineSimilaritySparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) +* double cosineSimilarity(List *, String) +* double cosineSimilaritySparse(Map *, String) * double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime) @@ -21,8 +21,8 @@ The following methods are directly callable without a class/instance qualifier. * double decayNumericExp(double *, double *, double *, double *, double) * double decayNumericGauss(double *, double *, double *, double *, double) * double decayNumericLinear(double *, double *, double *, double *, double) -* double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues) -* double dotProductSparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) +* double dotProduct(List, String) +* double dotProductSparse(Map *, String) * double randomScore(int *) * double randomScore(int *, String *) * double saturation(double, double) diff --git a/docs/reference/migration/migrate_7_6.asciidoc b/docs/reference/migration/migrate_7_6.asciidoc index d8b50c8be2c..a3c99fa41f3 100644 --- a/docs/reference/migration/migrate_7_6.asciidoc +++ b/docs/reference/migration/migrate_7_6.asciidoc @@ -29,3 +29,10 @@ We have not seen much interest in this experimental field type, and don't see a clear use case as it's currently designed. If you have feedback or suggestions around sparse vector functionality, please let us know through GitHub or the 'discuss' forums. + +[discrete] +==== Update to vector function signatures +The vector functions of the form `function(query, doc['field'])` are +deprecated, and the form `function(query, 'field')` should be used instead. 
+For example, `cosineSimilarity(query, doc['field'])` is replaced by +`cosineSimilarity(query, 'field')`. \ No newline at end of file diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index 4a23703b7ae..9db4757f035 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -68,7 +68,7 @@ GET my_index/_search } }, "script": { - "source": "cosineSimilarity(params.query_vector, doc['my_dense_vector']) + 1.0", <2> + "source": "cosineSimilarity(params.query_vector, 'my_dense_vector') + 1.0", <2> "params": { "query_vector": [4, 3.4, -0.2] <3> } @@ -105,7 +105,7 @@ GET my_index/_search }, "script": { "source": """ - double value = dotProduct(params.query_vector, doc['my_dense_vector']); + double value = dotProduct(params.query_vector, 'my_dense_vector'); return sigmoid(1, Math.E, -value); <1> """, "params": { @@ -139,7 +139,7 @@ GET my_index/_search } }, "script": { - "source": "1 / (1 + l1norm(params.queryVector, doc['my_dense_vector']))", <1> + "source": "1 / (1 + l1norm(params.queryVector, 'my_dense_vector'))", <1> "params": { "queryVector": [4, 3.4, -0.2] } @@ -178,7 +178,7 @@ GET my_index/_search } }, "script": { - "source": "1 / (1 + l2norm(params.queryVector, doc['my_dense_vector']))", + "source": "1 / (1 + l2norm(params.queryVector, 'my_dense_vector'))", "params": { "queryVector": [4, 3.4, -0.2] } @@ -196,7 +196,7 @@ You can check if a document has a value for the field `my_vector` by [source,js] -------------------------------------------------- -"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, doc['my_vector'])" +"source": "doc['my_vector'].size() == 0 ? 
0 : cosineSimilarity(params.queryVector, 'my_vector')" -------------------------------------------------- // NOTCONSOLE @@ -262,7 +262,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector']) + 1.0", + "source": "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector') + 1.0", "params": { "query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } @@ -294,7 +294,7 @@ GET my_sparse_index/_search }, "script": { "source": """ - double value = dotProductSparse(params.query_vector, doc['my_sparse_vector']); + double value = dotProductSparse(params.query_vector, 'my_sparse_vector'); return sigmoid(1, Math.E, -value); """, "params": { @@ -327,7 +327,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "1 / (1 + l1normSparse(params.queryVector, doc['my_sparse_vector']))", + "source": "1 / (1 + l1normSparse(params.queryVector, 'my_sparse_vector'))", "params": { "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } @@ -358,7 +358,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "1 / (1 + l2normSparse(params.queryVector, doc['my_sparse_vector']))", + "source": "1 / (1 + l2normSparse(params.queryVector, 'my_sparse_vector'))", "params": { "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java index 65417458b38..b95f4991965 100644 --- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java +++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java @@ -119,7 +119,7 @@ public abstract class ScoreScript { } /** The doc lookup for the Lucene segment this script was created for. 
*/ - public final Map> getDoc() { + public Map> getDoc() { return leafLookup.doc(); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml index 7631709a94c..b94165271ae 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml @@ -1,6 +1,6 @@ setup: - skip: - features: headers + features: [headers, warnings] version: " - 7.2.99" reason: "dense_vector dims parameter was added from 7.3" @@ -52,7 +52,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + source: "dotProduct(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -82,7 +82,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -99,3 +99,26 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +"Deprecated function signature": + - do: + headers: + Content-Type: application/json + warnings: + - The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field'). 
+ search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "1"} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml index dbb274d0776..882d11566df 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml @@ -53,7 +53,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1norm(params.query_vector, doc['my_dense_vector'])" + source: "l1norm(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -83,7 +83,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2norm(params.query_vector, doc['my_dense_vector'])" + source: "l2norm(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml index 4d9394dc2b7..03ae66152d1 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml @@ -62,7 +62,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10, 10, 10] @@ -81,7 +81,7 @@ setup: 
script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] @@ -111,7 +111,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [1, 2, 3, 4] - match: { error.root_cause.0.type: "script_exception" } @@ -125,7 +125,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + source: "dotProduct(params.query_vector, 'my_dense_vector')" params: query_vector: [1, 2, 3, 4] - match: { error.root_cause.0.type: "script_exception" } @@ -161,7 +161,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] - match: { error.root_cause.0.type: "script_exception" } @@ -177,7 +177,7 @@ setup: script_score: query: {match_all: {} } script: - source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "doc['my_dense_vector'].size() == 0 ? 
0 : cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] @@ -208,7 +208,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])" + source: "dotProductSparse(params.query_vector, 'my_dense_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "3": 44} - match: { error.root_cause.0.type: "script_exception" } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml index e184fd0ce93..803d82b0705 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml @@ -55,7 +55,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" + source: "dotProductSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -87,7 +87,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -104,3 +104,27 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +"Deprecated function signature": + - do: + headers: + Content-Type: application/json + warnings: + - The [sparse_vector] field type is deprecated and will be removed in 8.0. + - The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. 
For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field'). + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "1"} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml index 3a6ed9fd561..8a1ec0d3cdd 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml @@ -55,7 +55,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l1normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -88,7 +88,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l2normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml index b1c6c756c0b..0bfd800b759 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml @@ -61,7 +61,7 @@ setup: 
script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10} @@ -83,7 +83,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} @@ -127,7 +127,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} - match: { error.root_cause.0.type: "script_exception" } @@ -145,7 +145,7 @@ setup: script_score: query: {match_all: {} } script: - source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "doc['my_sparse_vector'].size() == 0 ? 
0 : cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} @@ -194,7 +194,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555} @@ -229,7 +229,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_sparse_vector'])" + source: "dotProduct(params.query_vector, 'my_sparse_vector')" params: query_vector: [0.5, 111] - match: { error.root_cause.0.type: "script_exception" } @@ -272,7 +272,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" + source: "dotProductSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} @@ -303,7 +303,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5" : 5} @@ -333,7 +333,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l1normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} @@ -360,7 +360,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l2normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index 91f2fc343b1..b4106654f77 100644 
--- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -9,12 +9,16 @@ package org.elasticsearch.xpack.vectors.query; import org.apache.logging.log4j.LogManager; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper; import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; +import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues; +import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.SparseVectorScriptDocValues; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -22,6 +26,10 @@ import java.util.Map; import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues; public class ScoreScriptUtils { + private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(ScoreScriptUtils.class)); + static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, and " + + "the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by " + + "cosineSimilarity(query, 'field')."; //**************FUNCTIONS FOR DENSE VECTORS // Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings. 
@@ -31,9 +39,12 @@ public class ScoreScriptUtils { public static class DenseVectorFunction { final ScoreScript scoreScript; final float[] queryVector; + final VectorScriptDocValues.DenseVectorScriptDocValues docValues; - public DenseVectorFunction(ScoreScript scoreScript, List queryVector) { - this(scoreScript, queryVector, false); + public DenseVectorFunction(ScoreScript scoreScript, + List queryVector, + Object field) { + this(scoreScript, queryVector, field, false); } /** @@ -45,6 +56,7 @@ public class ScoreScriptUtils { */ public DenseVectorFunction(ScoreScript scoreScript, List queryVector, + Object field, boolean normalizeQuery) { this.scoreScript = scoreScript; @@ -62,9 +74,28 @@ public class ScoreScriptUtils { this.queryVector[dim] /= queryMagnitude; } } + + if (field instanceof String) { + String fieldName = (String) field; + docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); + } else if (field instanceof DenseVectorScriptDocValues) { + docValues = (DenseVectorScriptDocValues) field; + deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE); + } else { + throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " + + "VectorScriptDocValues"); + } } - public void validateDocVector(BytesRef vector) { + BytesRef getEncodedVector() { + try { + docValues.setNextDocId(scoreScript._getDocId()); + } catch (IOException e) { + throw ExceptionsHelper.convertToElastic(e); + } + + // Validate the encoded vector's length. 
+ BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } @@ -74,20 +105,21 @@ public class ScoreScriptUtils { throw new IllegalArgumentException("The query vector has a different number of dimensions [" + queryVector.length + "] than the document vectors [" + vectorLength + "]."); } + return vector; } } // Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors public static final class L1Norm extends DenseVectorFunction { - public L1Norm(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public L1Norm(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field); } - public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l1norm() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + double l1norm = 0; for (float queryValue : queryVector) { @@ -100,13 +132,12 @@ public class ScoreScriptUtils { // Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors public static final class L2Norm extends DenseVectorFunction { - public L2Norm(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public L2Norm(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field); } - public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l2norm() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double l2norm = 0; @@ -121,13 +152,12 @@ public class ScoreScriptUtils { // Calculate a dot product between 
a query's dense vector and documents' dense vectors public static final class DotProduct extends DenseVectorFunction { - public DotProduct(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public DotProduct(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field); } - public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double dotProduct() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double dotProduct = 0; @@ -141,14 +171,12 @@ public class ScoreScriptUtils { // Calculate cosine similarity between a query's dense vector and documents' dense vectors public static final class CosineSimilarity extends DenseVectorFunction { - public CosineSimilarity(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector, true); + public CosineSimilarity(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field, true); } - public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double cosineSimilarity() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double dotProduct = 0.0; @@ -176,15 +204,17 @@ public class ScoreScriptUtils { // per script execution for all documents. 
public static class SparseVectorFunction { - static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(SparseVectorFunction.class)); - final ScoreScript scoreScript; final float[] queryValues; final int[] queryDims; + final VectorScriptDocValues.SparseVectorScriptDocValues docValues; + // prepare queryVector once per script execution // queryVector represents a map of dimensions to values - public SparseVectorFunction(ScoreScript scoreScript, Map queryVector) { + public SparseVectorFunction(ScoreScript scoreScript, + Map queryVector, + Object field) { this.scoreScript = scoreScript; //break vector into two arrays dims and values int n = queryVector.size(); @@ -203,28 +233,46 @@ public class ScoreScriptUtils { // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions sortSparseDimsFloatValues(queryDims, queryValues, n); + if (field instanceof String) { + String fieldName = (String) field; + docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); + } else if (field instanceof SparseVectorScriptDocValues) { + docValues = (SparseVectorScriptDocValues) field; + deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE); + } else { + throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " + + "VectorScriptDocValues"); + } + deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE); } - public void validateDocVector(BytesRef vector) { + BytesRef getEncodedVector() { + try { + docValues.setNextDocId(scoreScript._getDocId()); + } catch (IOException e) { + throw ExceptionsHelper.convertToElastic(e); + } + + BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } + return vector; } } // Calculate l1 norm (Manhattan distance) 
between a query's sparse vector and documents' sparse vectors public static final class L1NormSparse extends SparseVectorFunction { - public L1NormSparse(ScoreScript scoreScript,Map queryVector) { - super(scoreScript, queryVector); + public L1NormSparse(ScoreScript scoreScript,Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); } - public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double l1normSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int queryIndex = 0; int docIndex = 0; double l1norm = 0; @@ -255,16 +303,15 @@ public class ScoreScriptUtils { // Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors public static final class L2NormSparse extends SparseVectorFunction { - public L2NormSparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public L2NormSparse(ScoreScript scoreScript, Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); } - public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double l2normSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int queryIndex = 0; int docIndex = 0; double l2norm = 0; @@ -298,16 +345,15 @@ public class ScoreScriptUtils { // Calculate a dot product between a query's sparse vector and documents' sparse vectors public static final class DotProductSparse 
extends SparseVectorFunction { - public DotProductSparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public DotProductSparse(ScoreScript scoreScript, Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); } - public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double dotProductSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + return intDotProductSparse(queryValues, queryDims, docValues, docDims); } } @@ -316,8 +362,8 @@ public class ScoreScriptUtils { public static final class CosineSimilaritySparse extends SparseVectorFunction { final double queryVectorMagnitude; - public CosineSimilaritySparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public CosineSimilaritySparse(ScoreScript scoreScript, Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); double dotProduct = 0; for (int i = 0; i< queryDims.length; i++) { dotProduct += queryValues[i] * queryValues[i]; @@ -325,10 +371,8 @@ public class ScoreScriptUtils { this.queryVectorMagnitude = Math.sqrt(dotProduct); } - public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double cosineSimilaritySparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); diff --git a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt 
b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt index 42d6e6d0b0f..73155bf1333 100644 --- a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt +++ b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt @@ -13,12 +13,12 @@ class org.elasticsearch.script.ScoreScript @no_import { } static_import { - double l1norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm - double l2norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm - double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity - double dotProduct(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct - double l1normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse - double l2normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse - double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse - double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse + double 
l1norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm + double l2norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm + double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity + double dotProduct(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct + double l1normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse + double l2normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse + double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse + double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse } \ No newline at end of file diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java index bff87a5ac47..9aff40b359a 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ 
-39,6 +40,66 @@ public class ScoreScriptUtilsTests extends ESTestCase { } private void testDenseVectorFunctions(Version indexVersion) { + String field = "vector"; + float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); + VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); + + List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); + + // test dotProduct + DotProduct dotProduct = new DotProduct(scoreScript, queryVector, field); + double result = dotProduct.dotProduct(); + assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); + + // test cosineSimilarity + CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, field); + double result2 = cosineSimilarity.cosineSimilarity(); + assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001); + + // test l1Norm + L1Norm l1norm = new L1Norm(scoreScript, queryVector, field); + double result3 = l1norm.l1norm(); + assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001); + + // test l2norm + L2Norm l2norm = new L2Norm(scoreScript, queryVector, field); + double result4 = l2norm.l2norm(); + assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001); + + // test dotProduct fails when queryVector has wrong number of dims + List invalidQueryVector = Arrays.asList(0.5, 111.3); + DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, field); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, 
dotProduct2::dotProduct); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + + // test cosineSimilarity fails when queryVector has wrong number of dims + CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, field); + e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + + // test l1norm fails when queryVector has wrong number of dims + L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, field); + e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + + // test l2norm fails when queryVector has wrong number of dims + L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, field); + e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + } + + public void testDeprecatedDenseVectorFunctions() { + testDeprecatedDenseVectorFunctions(Version.V_7_4_0); + testDeprecatedDenseVectorFunctions(Version.CURRENT); + } + + private void testDeprecatedDenseVectorFunctions(Version indexVersion) { float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); @@ -50,45 +111,53 @@ public class ScoreScriptUtilsTests extends ESTestCase { List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); // test dotProduct - DotProduct dotProduct = new DotProduct(scoreScript, queryVector); - double 
result = dotProduct.dotProduct(dvs); + DotProduct dotProduct = new DotProduct(scoreScript, queryVector, dvs); + double result = dotProduct.dotProduct(); assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); // test cosineSimilarity - CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector); - double result2 = cosineSimilarity.cosineSimilarity(dvs); + CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, dvs); + double result2 = cosineSimilarity.cosineSimilarity(); assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); // test l1Norm - L1Norm l1norm = new L1Norm(scoreScript, queryVector); - double result3 = l1norm.l1norm(dvs); + L1Norm l1norm = new L1Norm(scoreScript, queryVector, dvs); + double result3 = l1norm.l1norm(); assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); // test l2norm - L2Norm l2norm = new L2Norm(scoreScript, queryVector); - double result4 = l2norm.l2norm(dvs); + L2Norm l2norm = new L2Norm(scoreScript, queryVector, dvs); + double result4 = l2norm.l2norm(); assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); // test dotProduct fails when queryVector has wrong number of dims List invalidQueryVector = Arrays.asList(0.5, 111.3); - DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct2.dotProduct(dvs)); + DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, dvs); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct); 
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); // test cosineSimilarity fails when queryVector has wrong number of dims - CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs)); + CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, dvs); + e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); // test l1norm fails when queryVector has wrong number of dims - L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> l1norm2.l1norm(dvs)); + L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, dvs); + e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); // test l2norm fails when queryVector has wrong number of dims - L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> l2norm2.l2norm(dvs)); + L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, dvs); + e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); } public void testSparseVectorFunctions() { @@ -97,12 +166,62 @@ public class ScoreScriptUtilsTests extends 
ESTestCase { } private void testSparseVectorFunctions(Version indexVersion) { + String field = "vector"; + int[] docVectorDims = {2, 10, 50, 113, 4545}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( indexVersion, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); + + Map queryVector = new HashMap() {{ + put("2", 0.5); + put("10", 111.3); + put("50", -13.0); + put("113", 14.8); + put("4545", -156.0); + }}; + + // test dotProduct + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field); + double result = docProductSparse.dotProductSparse(); + assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); + + // test cosineSimilarity + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); + assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001); + + // test l1norm + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field); + double result3 = l1Norm.l1normSparse(); + assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001); + + // test l2norm + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field); + double result4 = l2Norm.l2normSparse(); + assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001); + 
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); + } + + public void testDeprecatedSparseVectorFunctions() { + testDeprecatedSparseVectorFunctions(Version.V_7_4_0); + testDeprecatedSparseVectorFunctions(Version.CURRENT); + } + + private void testDeprecatedSparseVectorFunctions(Version indexVersion) { + int[] docVectorDims = {2, 10, 50, 113, 4545}; + float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( + indexVersion, docVectorDims, docVectorValues, docVectorDims.length); + VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(indexVersion); @@ -115,29 +234,33 @@ public class ScoreScriptUtilsTests extends ESTestCase { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, dvs); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, dvs); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001); + 
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, dvs); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, dvs); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001); - - assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); } public void testSparseVectorMissingDimensions1() { + String field = "vector"; + // Document vector's biggest dimension > query vector's biggest dimension int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; @@ -145,8 +268,11 @@ public class ScoreScriptUtilsTests extends ESTestCase { Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); + Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -157,29 +283,33 @@ public class ScoreScriptUtilsTests extends 
ESTestCase { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); - assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); } public void testSparseVectorMissingDimensions2() { + String field = "vector"; + // Document vector's biggest dimension < query vector's biggest 
dimension int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; @@ -187,8 +317,11 @@ public class ScoreScriptUtilsTests extends ESTestCase { Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); + Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -199,25 +332,27 @@ public class ScoreScriptUtilsTests extends ESTestCase { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new 
L1NormSparse(scoreScript, queryVector, field); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); - assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); } }