Convert dense/sparse vector field mappers to Parametrized form (#62992)

Also adds a proper MapperTestCase test for dense vectors.

Relates to #62988
This commit is contained in:
Alan Woodward 2020-09-30 16:08:17 +01:00 committed by Alan Woodward
parent 3838fe1fd4
commit 675d18f6ea
9 changed files with 314 additions and 441 deletions

View File

@ -96,13 +96,17 @@ public abstract class MapperServiceTestCase extends ESTestCase {
return createMapperService(mappings).documentMapper();
}
protected final DocumentMapper createDocumentMapper(Version version, XContentBuilder mappings) throws IOException {
return createMapperService(version, mappings).documentMapper();
}
protected final DocumentMapper createDocumentMapper(String type, String mappings) throws IOException {
MapperService mapperService = createMapperService(mapping(b -> {}));
merge(type, mapperService, mappings);
return mapperService.documentMapper();
}
protected final MapperService createMapperService(XContentBuilder mappings) throws IOException {
protected MapperService createMapperService(XContentBuilder mappings) throws IOException {
return createMapperService(Version.CURRENT, mappings);
}

View File

@ -217,7 +217,7 @@ public abstract class MapperTestCase extends MapperServiceTestCase {
minimalMapping(b);
}
public final void testMeta() throws IOException {
public void testMeta() throws IOException {
assumeTrue("Field doesn't support meta", supportsMeta());
XContentBuilder mapping = fieldMapping(
b -> {

View File

@ -1,5 +1,4 @@
apply plugin: 'elasticsearch.esplugin'
apply plugin: 'elasticsearch.internal-cluster-test'
esplugin {
name 'vectors'

View File

@ -1,212 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.vectors.mapper;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.mapper.DocumentMapper;
import org.elasticsearch.index.mapper.DocumentMapperParser;
import org.elasticsearch.index.mapper.FieldMapperTestCase;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.mapper.SourceToParse;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.vectors.Vectors;
import org.junit.Before;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Set;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
/**
 * Legacy tests for the dense_vector field mapper, built on the old
 * FieldMapperTestCase framework. In the enclosing commit this file is
 * deleted (hunk header shows -1,212 +0,0) and replaced by a
 * MapperTestCase-based version added later in the same commit.
 */
public class DenseVectorFieldMapperTests extends FieldMapperTestCase<DenseVectorFieldMapper.Builder> {
// Minimal builder the base test framework exercises: a 4-dimension dense vector.
@Override
protected DenseVectorFieldMapper.Builder newBuilder() {
return new DenseVectorFieldMapper.Builder("densevector").dims(4);
}
// Registers the vectors plugin (plus the composite x-pack plugin the old framework needs).
@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(Vectors.class, LocalStateCompositeXPackPlugin.class);
}
// Mapping properties that dense_vector does not accept; the base class
// asserts that using them fails.
@Override
protected Set<String> unsupportedProperties() {
return org.elasticsearch.common.collect.Set.of("analyzer", "similarity", "doc_values", "store", "index");
}
// Teaches the base class that changing "dims" (3 vs 4) is a non-mergeable
// mapping update (second argument 'false' = update not allowed).
@Before
public void addModifiers() {
addModifier("dims", false, (a, b) -> {
a.dims(3);
b.dims(4);
});
}
// Allows the tests to set index.version.created, which is a private setting.
@Override
protected boolean forbidPrivateIndexSettings() {
return false;
}
// Mapping with dims = MAX_DIMS_COUNT + 1 must be rejected with a range error.
public void testMappingExceedDimsLimit() throws IOException {
IndexService indexService = createIndex("test-index");
DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder()
.startObject()
.startObject("_doc")
.startObject("properties")
.startObject("my-dense-vector").field("type", "dense_vector").field("dims", DenseVectorFieldMapper.MAX_DIMS_COUNT + 1)
.endObject()
.endObject()
.endObject()
.endObject());
MapperParsingException e = expectThrows(MapperParsingException.class, () -> parser.parse("_doc", new CompressedXContent(mapping)));
assertEquals(e.getMessage(),
"The number of dimensions for field [my-dense-vector] should be in the range [1, 2048] but was [2049]");
}
// Round-trip on a current-version index: index a 3-dim vector, then decode the
// stored BinaryDocValuesField and check both the values and the appended magnitude.
public void testDefaults() throws Exception {
Version indexVersion = Version.CURRENT;
IndexService indexService = createIndex("test-index");
DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder()
.startObject()
.startObject("_doc")
.startObject("properties")
.startObject("my-dense-vector").field("type", "dense_vector").field("dims", 3)
.endObject()
.endObject()
.endObject()
.endObject());
DocumentMapper mapper = parser.parse("_doc", new CompressedXContent(mapping));
float[] validVector = {-12.1f, 100.7f, -4};
// Expected magnitude = sqrt(sum of squares), matching what the mapper encodes.
double dotProduct = 0.0f;
for (float value: validVector) {
dotProduct += value * value;
}
float expectedMagnitude = (float) Math.sqrt(dotProduct);
ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startArray("my-dense-vector").value(validVector[0]).value(validVector[1]).value(validVector[2]).endArray()
.endObject()),
XContentType.JSON));
IndexableField[] fields = doc1.rootDoc().getFields("my-dense-vector");
assertEquals(1, fields.length);
assertThat(fields[0], instanceOf(BinaryDocValuesField.class));
// assert that after decoding the indexed value is equal to expected
BytesRef vectorBR = fields[0].binaryValue();
float[] decodedValues = decodeDenseVector(indexVersion, vectorBR);
float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR);
assertEquals(expectedMagnitude, decodedMagnitude, 0.001f);
assertArrayEquals(
"Decoded dense vector values is not equal to the indexed one.",
validVector,
decodedValues,
0.001f
);
}
// Same round-trip against an index created on 7.4.0: pre-7.5 encoding has no
// stored magnitude, so only the vector values are decoded and compared.
public void testAddDocumentsToIndexBefore_V_7_5_0() throws Exception {
Version indexVersion = Version.V_7_4_0;
IndexService indexService = createIndex("test-index7_4",
Settings.builder().put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), indexVersion).build());
DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder()
.startObject()
.startObject("_doc")
.startObject("properties")
.startObject("my-dense-vector").field("type", "dense_vector").field("dims", 3)
.endObject()
.endObject()
.endObject()
.endObject());
DocumentMapper mapper = parser.parse("_doc", new CompressedXContent(mapping));
float[] validVector = {-12.1f, 100.7f, -4};
ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index7_4", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startArray("my-dense-vector").value(validVector[0]).value(validVector[1]).value(validVector[2]).endArray()
.endObject()),
XContentType.JSON));
IndexableField[] fields = doc1.rootDoc().getFields("my-dense-vector");
assertEquals(1, fields.length);
assertThat(fields[0], instanceOf(BinaryDocValuesField.class));
// assert that after decoding the indexed value is equal to expected
BytesRef vectorBR = fields[0].binaryValue();
float[] decodedValues = decodeDenseVector(indexVersion, vectorBR);
assertArrayEquals(
"Decoded dense vector values is not equal to the indexed one.",
validVector,
decodedValues,
0.001f
);
}
// Decodes the version-specific binary doc-values encoding back into a float[]:
// dimension count comes from VectorEncoderDecoder, then one big-endian float per dim.
private static float[] decodeDenseVector(Version indexVersion, BytesRef encodedVector) {
int dimCount = VectorEncoderDecoder.denseVectorLength(indexVersion, encodedVector);
float[] vector = new float[dimCount];
ByteBuffer byteBuffer = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length);
for (int dim = 0; dim < dimCount; dim++) {
vector[dim] = byteBuffer.getFloat();
}
return vector;
}
// Documents whose vector length differs from the mapped "dims" (both too many
// and too few) must fail at parse time with a descriptive cause message.
public void testDocumentsWithIncorrectDims() throws Exception {
IndexService indexService = createIndex("test-index");
int dims = 3;
DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder()
.startObject()
.startObject("_doc")
.startObject("properties")
.startObject("my-dense-vector").field("type", "dense_vector").field("dims", dims)
.endObject()
.endObject()
.endObject()
.endObject());
DocumentMapper mapper = parser.parse("_doc", new CompressedXContent(mapping));
// test that error is thrown when a document has number of dims more than defined in the mapping
float[] invalidVector = new float[dims + 1];
BytesReference invalidDoc = BytesReference.bytes(XContentFactory.jsonBuilder().startObject()
.array("my-dense-vector", invalidVector)
.endObject());
MapperParsingException e = expectThrows(MapperParsingException.class, () -> mapper.parse(
new SourceToParse("test-index", "_doc", "1", invalidDoc, XContentType.JSON)));
assertThat(e.getCause().getMessage(), containsString("has exceeded the number of dimensions [3] defined in mapping"));
// test that error is thrown when a document has number of dims less than defined in the mapping
float[] invalidVector2 = new float[dims - 1];
BytesReference invalidDoc2 = BytesReference.bytes(XContentFactory.jsonBuilder().startObject()
.array("my-dense-vector", invalidVector2)
.endObject());
MapperParsingException e2 = expectThrows(MapperParsingException.class, () -> mapper.parse(
new SourceToParse("test-index", "_doc", "2", invalidDoc2, XContentType.JSON)));
assertThat(e2.getCause().getMessage(), containsString("has number of dimensions [2] less than defined in the mapping [3]"));
}
}

View File

@ -34,8 +34,8 @@ public class Vectors extends Plugin implements MapperPlugin {
@Override
public Map<String, Mapper.TypeParser> getMappers() {
Map<String, Mapper.TypeParser> mappers = new LinkedHashMap<>();
mappers.put(DenseVectorFieldMapper.CONTENT_TYPE, new DenseVectorFieldMapper.TypeParser());
mappers.put(SparseVectorFieldMapper.CONTENT_TYPE, new SparseVectorFieldMapper.TypeParser());
mappers.put(DenseVectorFieldMapper.CONTENT_TYPE, DenseVectorFieldMapper.PARSER);
mappers.put(SparseVectorFieldMapper.CONTENT_TYPE, SparseVectorFieldMapper.PARSER);
return Collections.unmodifiableMap(mappers);
}
}

View File

@ -8,21 +8,18 @@
package org.elasticsearch.xpack.vectors.mapper;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.ParametrizedFieldMapper;
import org.elasticsearch.index.mapper.ParseContext;
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextSearchInfo;
@ -36,6 +33,7 @@ import org.elasticsearch.xpack.vectors.query.VectorIndexFieldData;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.ZoneId;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
@ -45,60 +43,56 @@ import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpect
/**
* A {@link FieldMapper} for indexing a dense vector of floats.
*/
public class DenseVectorFieldMapper extends FieldMapper {
public class DenseVectorFieldMapper extends ParametrizedFieldMapper {
public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; //maximum allowed number of dimensions
private static final byte INT_BYTES = 4;
public static class Defaults {
public static final FieldType FIELD_TYPE = new FieldType();
static {
FIELD_TYPE.setTokenized(false);
FIELD_TYPE.setIndexOptions(IndexOptions.NONE);
FIELD_TYPE.setOmitNorms(true);
FIELD_TYPE.freeze();
}
private static DenseVectorFieldMapper toType(FieldMapper in) {
return (DenseVectorFieldMapper) in;
}
public static class Builder extends FieldMapper.Builder<Builder> {
private int dims = 0;
public static class Builder extends ParametrizedFieldMapper.Builder {
public Builder(String name) {
super(name, Defaults.FIELD_TYPE);
builder = this;
Parameter<Integer> dims
= new Parameter<>("dims", false, () -> null, (n, c, o) -> XContentMapValues.nodeIntegerValue(o), m -> toType(m).dims)
.setValidator(dims -> {
if (dims == null) {
throw new MapperParsingException("Missing required parameter [dims] for field [" + name + "]");
}
if ((dims > MAX_DIMS_COUNT) || (dims < 1)) {
throw new MapperParsingException("The number of dimensions for field [" + name +
"] should be in the range [1, " + MAX_DIMS_COUNT + "] but was [" + dims + "]");
}
});
Parameter<Map<String, String>> meta = Parameter.metaParam();
final Version indexVersionCreated;
public Builder(String name, Version indexVersionCreated) {
super(name);
this.indexVersionCreated = indexVersionCreated;
}
public Builder dims(int dims) {
if ((dims > MAX_DIMS_COUNT) || (dims < 1)) {
throw new MapperParsingException("The number of dimensions for field [" + name +
"] should be in the range [1, " + MAX_DIMS_COUNT + "] but was [" + dims + "]");
}
this.dims = dims;
return this;
@Override
protected List<Parameter<?>> getParameters() {
return Arrays.asList(dims, meta);
}
@Override
public DenseVectorFieldMapper build(BuilderContext context) {
return new DenseVectorFieldMapper(
name, fieldType, new DenseVectorFieldType(buildFullName(context), dims, meta),
context.indexSettings(), multiFieldsBuilder.build(this, context), copyTo);
name,
new DenseVectorFieldType(buildFullName(context), dims.getValue(), meta.getValue()),
dims.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, context),
copyTo.build());
}
}
public static class TypeParser implements Mapper.TypeParser {
@Override
public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserContext parserContext) throws MapperParsingException {
DenseVectorFieldMapper.Builder builder = new DenseVectorFieldMapper.Builder(name);
Object dimsField = node.remove("dims");
if (dimsField == null) {
throw new MapperParsingException("The [dims] property must be specified for field [" + name + "].");
}
int dims = XContentMapValues.nodeIntegerValue(dimsField);
return builder.dims(dims);
}
}
public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n, c.indexVersionCreated()));
public static final class DenseVectorFieldType extends MappedFieldType {
private final int dims;
@ -141,12 +135,14 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
private final Version indexCreatedVersion;
private final int dims;
private DenseVectorFieldMapper(String simpleName, FieldType fieldType, MappedFieldType mappedFieldType,
Settings indexSettings, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, fieldType, mappedFieldType, multiFields, copyTo);
private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims,
Version indexCreatedVersion, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, mappedFieldType, multiFields, copyTo);
assert fieldType.indexOptions() == IndexOptions.NONE;
this.indexCreatedVersion = Version.indexCreated(indexSettings);
this.indexCreatedVersion = indexCreatedVersion;
this.dims = dims;
}
@Override
@ -232,20 +228,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
return true;
}
@Override
protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException {
super.doXContentBody(builder, includeDefaults, params);
builder.field("dims", fieldType().dims());
}
@Override
protected void mergeOptions(FieldMapper other, List<String> conflicts) {
DenseVectorFieldType otherType = (DenseVectorFieldType) other.fieldType();
if (this.fieldType().dims() != otherType.dims()) {
conflicts.add("mapper [" + name() + "] has different dims");
}
}
@Override
protected void parseCreateField(ParseContext context) {
throw new AssertionError("parse is implemented directly");
@ -255,4 +237,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
protected String contentType() {
return CONTENT_TYPE;
}
@Override
public ParametrizedFieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName(), indexCreatedVersion).init(this);
}
}

View File

@ -8,8 +8,6 @@
package org.elasticsearch.xpack.vectors.mapper;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.ArrayUtil;
@ -20,9 +18,8 @@ import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.ParametrizedFieldMapper;
import org.elasticsearch.index.mapper.ParseContext;
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextSearchInfo;
@ -35,6 +32,7 @@ import org.elasticsearch.xpack.vectors.query.VectorIndexFieldData;
import java.io.IOException;
import java.time.ZoneId;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
@ -44,7 +42,9 @@ import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpect
/**
* A {@link FieldMapper} for indexing a sparse vector of floats.
*/
public class SparseVectorFieldMapper extends FieldMapper {
@Deprecated
public class SparseVectorFieldMapper extends ParametrizedFieldMapper {
private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SparseVectorFieldMapper.class);
public static final String DEPRECATION_MESSAGE = "The [sparse_vector] field type is deprecated and will be removed in 8.0.";
@ -52,40 +52,34 @@ public class SparseVectorFieldMapper extends FieldMapper {
public static short MAX_DIMS_COUNT = 1024; //maximum allowed number of dimensions
public static int MAX_DIMS_NUMBER = 65535; //maximum allowed dimension's number
public static class Defaults {
public static final FieldType FIELD_TYPE = new FieldType();
public static class Builder extends ParametrizedFieldMapper.Builder {
static {
FIELD_TYPE.setTokenized(false);
FIELD_TYPE.setIndexOptions(IndexOptions.NONE);
FIELD_TYPE.setOmitNorms(true);
FIELD_TYPE.freeze();
final Parameter<Map<String, String>> meta = Parameter.metaParam();
final Version indexCreatedVersion;
public Builder(String name, Version indexCreatedVersion) {
super(name);
this.indexCreatedVersion = indexCreatedVersion;
}
}
public static class Builder extends FieldMapper.Builder<Builder> {
public Builder(String name) {
super(name, Defaults.FIELD_TYPE);
builder = this;
@Override
protected List<Parameter<?>> getParameters() {
return Collections.singletonList(meta);
}
@Override
public SparseVectorFieldMapper build(BuilderContext context) {
return new SparseVectorFieldMapper(
name, fieldType, new SparseVectorFieldType(buildFullName(context), meta),
context.indexCreatedVersion(), multiFieldsBuilder.build(this, context), copyTo);
name, new SparseVectorFieldType(buildFullName(context), meta.getValue()),
multiFieldsBuilder.build(this, context), copyTo.build(), indexCreatedVersion);
}
}
public static class TypeParser implements Mapper.TypeParser {
@Override
public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserContext parserContext) throws MapperParsingException {
deprecationLogger.deprecate("sparse_vector", DEPRECATION_MESSAGE);
SparseVectorFieldMapper.Builder builder = new SparseVectorFieldMapper.Builder(name);
return builder;
}
}
public static final TypeParser PARSER = new TypeParser((n, c) -> {
deprecationLogger.deprecate("sparse_vector", DEPRECATION_MESSAGE);
return new Builder(n, c.indexVersionCreated());
});
public static final class SparseVectorFieldType extends MappedFieldType {
@ -128,10 +122,9 @@ public class SparseVectorFieldMapper extends FieldMapper {
private final Version indexCreatedVersion;
private SparseVectorFieldMapper(String simpleName, FieldType fieldType, MappedFieldType mappedFieldType,
Version indexCreatedVersion, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, fieldType, mappedFieldType, multiFields, copyTo);
assert fieldType.indexOptions() == IndexOptions.NONE;
private SparseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType,
MultiFields multiFields, CopyTo copyTo, Version indexCreatedVersion) {
super(simpleName, mappedFieldType, multiFields, copyTo);
this.indexCreatedVersion = indexCreatedVersion;
}
@ -140,11 +133,6 @@ public class SparseVectorFieldMapper extends FieldMapper {
return (SparseVectorFieldMapper) super.clone();
}
@Override
protected void mergeOptions(FieldMapper other, List<String> conflicts) {
}
@Override
public SparseVectorFieldType fieldType() {
return (SparseVectorFieldType) super.fieldType();
@ -228,4 +216,9 @@ public class SparseVectorFieldMapper extends FieldMapper {
protected String contentType() {
return CONTENT_TYPE;
}
@Override
public ParametrizedFieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName(), indexCreatedVersion).init(this);
}
}

View File

@ -0,0 +1,155 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.vectors.mapper;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.mapper.DocumentMapper;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MapperTestCase;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xpack.vectors.Vectors;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
/**
 * Tests for the dense_vector field mapper, rewritten on the MapperTestCase
 * framework. Added by this commit (hunk header shows -0,0 +1,155) as the
 * replacement for the deleted FieldMapperTestCase-based version.
 */
public class DenseVectorFieldMapperTests extends MapperTestCase {
// The vectors plugin supplies the dense_vector mapper under test.
@Override
protected Collection<? extends Plugin> getPlugins() {
return Collections.singletonList(new Vectors());
}
// Smallest valid mapping for this field type; "dims" is mandatory.
@Override
protected void minimalMapping(XContentBuilder b) throws IOException {
b.field("type", "dense_vector").field("dims", 4);
}
// A valid sample value matching the 4-dim minimal mapping above.
@Override
protected void writeFieldValue(XContentBuilder builder) throws IOException {
builder.startArray().value(1).value(2).value(3).value(4).endArray();
}
// "dims" cannot be updated in place: changing 4 -> 5 must be reported as a conflict.
@Override
protected void registerParameters(ParameterChecker checker) throws IOException {
checker.registerConflictCheck("dims",
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4)),
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 5)));
}
// Validation of the "dims" parameter: out-of-range (0 and 3000) and missing
// values must all fail mapping parsing with precise messages.
public void testDims() {
{
Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
b.field("type", "dense_vector");
b.field("dims", 0);
})));
assertThat(e.getMessage(), equalTo("Failed to parse mapping [_doc]: " +
"The number of dimensions for field [field] should be in the range [1, 2048] but was [0]"));
}
{
Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
b.field("type", "dense_vector");
b.field("dims", 3000);
})));
assertThat(e.getMessage(), equalTo("Failed to parse mapping [_doc]: " +
"The number of dimensions for field [field] should be in the range [1, 2048] but was [3000]"));
}
{
Exception e = expectThrows(MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "dense_vector"))));
assertThat(e.getMessage(), equalTo("Failed to parse mapping [_doc]: Missing required parameter [dims] for field [field]"));
}
}
// Round-trip on a current-version index: index a 3-dim vector, then decode the
// stored BinaryDocValuesField and check both the values and the appended magnitude.
public void testDefaults() throws Exception {
DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dims", 3)));
float[] validVector = {-12.1f, 100.7f, -4};
// Expected magnitude = sqrt(sum of squares), matching what the mapper encodes.
double dotProduct = 0.0f;
for (float value: validVector) {
dotProduct += value * value;
}
float expectedMagnitude = (float) Math.sqrt(dotProduct);
ParsedDocument doc1 = mapper.parse(source(b -> b.array("field", validVector)));
IndexableField[] fields = doc1.rootDoc().getFields("field");
assertEquals(1, fields.length);
assertThat(fields[0], instanceOf(BinaryDocValuesField.class));
// assert that after decoding the indexed value is equal to expected
BytesRef vectorBR = fields[0].binaryValue();
float[] decodedValues = decodeDenseVector(Version.CURRENT, vectorBR);
float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(Version.CURRENT, vectorBR);
assertEquals(expectedMagnitude, decodedMagnitude, 0.001f);
assertArrayEquals(
"Decoded dense vector values is not equal to the indexed one.",
validVector,
decodedValues,
0.001f
);
}
// Same round-trip against an index created on 7.4.0: pre-7.5 encoding has no
// stored magnitude, so only the vector values are decoded and compared.
public void testAddDocumentsToIndexBefore_V_7_5_0() throws Exception {
Version indexVersion = Version.V_7_4_0;
DocumentMapper mapper
= createDocumentMapper(indexVersion, fieldMapping(b -> b.field("type", "dense_vector").field("dims", 3)));
float[] validVector = {-12.1f, 100.7f, -4};
ParsedDocument doc1 = mapper.parse(source(b -> b.array("field", validVector)));
IndexableField[] fields = doc1.rootDoc().getFields("field");
assertEquals(1, fields.length);
assertThat(fields[0], instanceOf(BinaryDocValuesField.class));
// assert that after decoding the indexed value is equal to expected
BytesRef vectorBR = fields[0].binaryValue();
float[] decodedValues = decodeDenseVector(indexVersion, vectorBR);
assertArrayEquals(
"Decoded dense vector values is not equal to the indexed one.",
validVector,
decodedValues,
0.001f
);
}
// Decodes the version-specific binary doc-values encoding back into a float[]:
// dimension count comes from VectorEncoderDecoder, then one big-endian float per dim.
private static float[] decodeDenseVector(Version indexVersion, BytesRef encodedVector) {
int dimCount = VectorEncoderDecoder.denseVectorLength(indexVersion, encodedVector);
float[] vector = new float[dimCount];
ByteBuffer byteBuffer = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length);
for (int dim = 0; dim < dimCount; dim++) {
vector[dim] = byteBuffer.getFloat();
}
return vector;
}
// Documents whose vector length differs from the mapped "dims" (both too many
// and too few) must fail at parse time with a descriptive cause message.
public void testDocumentsWithIncorrectDims() throws Exception {
int dims = 3;
DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims)));
// test that error is thrown when a document has number of dims more than defined in the mapping
float[] invalidVector = new float[dims + 1];
MapperParsingException e = expectThrows(MapperParsingException.class,
() -> mapper.parse(source(b -> b.array("field", invalidVector))));
assertThat(e.getCause().getMessage(), containsString("has exceeded the number of dimensions [3] defined in mapping"));
// test that error is thrown when a document has number of dims less than defined in the mapping
float[] invalidVector2 = new float[dims - 1];
MapperParsingException e2 = expectThrows(MapperParsingException.class,
() -> mapper.parse(source(b -> b.array("field", invalidVector2))));
assertThat(e2.getCause().getMessage(), containsString("has number of dimensions [2] less than defined in the mapping [3]"));
}
}

View File

@ -11,30 +11,19 @@ import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.mapper.DocumentMapper;
import org.elasticsearch.index.mapper.DocumentMapperParser;
import org.elasticsearch.index.mapper.FieldMapperTestCase;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MapperTestCase;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.mapper.SourceToParse;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.vectors.Vectors;
import org.hamcrest.Matchers;
import org.junit.Before;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@ -42,56 +31,52 @@ import java.util.stream.IntStream;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
public class SparseVectorFieldMapperTests extends FieldMapperTestCase<SparseVectorFieldMapper.Builder> {
private DocumentMapper mapper;
@SuppressWarnings("deprecation")
public class SparseVectorFieldMapperTests extends MapperTestCase {
@Override
protected Set<String> unsupportedProperties() {
return org.elasticsearch.common.collect.Set.of("analyzer", "similarity", "doc_values", "store", "index");
}
@Before
public void setUpMapper() throws Exception {
IndexService indexService = createIndex("test-index");
DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder()
.startObject()
.startObject("_doc")
.startObject("properties")
.startObject("my-sparse-vector").field("type", "sparse_vector")
.endObject()
.endObject()
.endObject()
.endObject());
mapper = parser.parse("_doc", new CompressedXContent(mapping));
protected void assertParseMinimalWarnings() {
assertWarnings("The [sparse_vector] field type is deprecated and will be removed in 8.0.");
}
@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(Vectors.class, LocalStateCompositeXPackPlugin.class);
protected void assertParseMaximalWarnings() {
assertParseMinimalWarnings();
}
// this allows to set indexVersion as it is a private setting
@Override
protected boolean forbidPrivateIndexSettings() {
return false;
protected void registerParameters(ParameterChecker checker) {
// no parameters to check
}
@Override
protected void minimalMapping(XContentBuilder b) throws IOException {
b.field("type", "sparse_vector");
}
@Override
protected void writeFieldValue(XContentBuilder builder) throws IOException {
builder.startObject().field("1", 1).endObject();
}
@Override
protected Collection<Plugin> getPlugins() {
return Collections.singletonList(new Vectors());
}
public void testDefaults() throws Exception {
Version indexVersion = Version.CURRENT;
int[] indexedDims = {65535, 50, 2};
float[] indexedValues = {0.5f, 1800f, -34567.11f};
ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startObject("my-sparse-vector")
.field(Integer.toString(indexedDims[0]), indexedValues[0])
.field(Integer.toString(indexedDims[1]), indexedValues[1])
.field(Integer.toString(indexedDims[2]), indexedValues[2])
.endObject()
.endObject()),
XContentType.JSON));
IndexableField[] fields = doc1.rootDoc().getFields("my-sparse-vector");
DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
ParsedDocument doc1 = mapper.parse(source(b -> {
b.startObject("field");
b.field(Integer.toString(indexedDims[0]), indexedValues[0]);
b.field(Integer.toString(indexedDims[1]), indexedValues[1]);
b.field(Integer.toString(indexedDims[2]), indexedValues[2]);
b.endObject();
}));
IndexableField[] fields = doc1.rootDoc().getFields("field");
assertEquals(1, fields.length);
assertThat(fields[0], Matchers.instanceOf(BinaryDocValuesField.class));
@ -127,33 +112,18 @@ public class SparseVectorFieldMapperTests extends FieldMapperTestCase<SparseVect
public void testAddDocumentsToIndexBefore_V_7_5_0() throws Exception {
Version indexVersion = Version.V_7_4_0;
IndexService indexService = createIndex("test-index7_4",
Settings.builder().put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), indexVersion).build());
DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder()
.startObject()
.startObject("_doc")
.startObject("properties")
.startObject("my-sparse-vector").field("type", "sparse_vector")
.endObject()
.endObject()
.endObject()
.endObject());
mapper = parser.parse("_doc", new CompressedXContent(mapping));
DocumentMapper mapper = createDocumentMapper(indexVersion, fieldMapping(this::minimalMapping));
int[] indexedDims = {65535, 50, 2};
float[] indexedValues = {0.5f, 1800f, -34567.11f};
ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index7_4", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startObject("my-sparse-vector")
.field(Integer.toString(indexedDims[0]), indexedValues[0])
.field(Integer.toString(indexedDims[1]), indexedValues[1])
.field(Integer.toString(indexedDims[2]), indexedValues[2])
.endObject()
.endObject()),
XContentType.JSON));
IndexableField[] fields = doc1.rootDoc().getFields("my-sparse-vector");
ParsedDocument doc1 = mapper.parse(source(b -> {
b.startObject("field");
b.field(Integer.toString(indexedDims[0]), indexedValues[0]);
b.field(Integer.toString(indexedDims[1]), indexedValues[1]);
b.field(Integer.toString(indexedDims[2]), indexedValues[2]);
b.endObject();
}));
IndexableField[] fields = doc1.rootDoc().getFields("field");
assertEquals(1, fields.length);
assertThat(fields[0], Matchers.instanceOf(BinaryDocValuesField.class));
@ -180,17 +150,15 @@ public class SparseVectorFieldMapperTests extends FieldMapperTestCase<SparseVect
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
public void testDimensionNumberValidation() {
public void testDimensionNumberValidation() throws IOException {
// 1. test for an error on negative dimension
DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
MapperParsingException e = expectThrows(MapperParsingException.class, () -> {
mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startObject("my-sparse-vector")
.field(Integer.toString(-50), 100f)
.endObject()
.endObject()),
XContentType.JSON));
mapper.parse(source(b -> {
b.startObject("field");
b.field("-50", 100f);
b.endObject();
}));
});
assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
assertThat(e.getCause().getMessage(), containsString(
@ -198,14 +166,11 @@ public class SparseVectorFieldMapperTests extends FieldMapperTestCase<SparseVect
// 2. test for an error on a dimension greater than MAX_DIMS_NUMBER
e = expectThrows(MapperParsingException.class, () -> {
mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startObject("my-sparse-vector")
.field(Integer.toString(70000), 100f)
.endObject()
.endObject()),
XContentType.JSON));
mapper.parse(source(b -> {
b.startObject("field");
b.field("70000", 100f);
b.endObject();
}));
});
assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
assertThat(e.getCause().getMessage(), containsString(
@ -213,14 +178,11 @@ public class SparseVectorFieldMapperTests extends FieldMapperTestCase<SparseVect
// 3. test for an error on a wrong formatted dimension
e = expectThrows(MapperParsingException.class, () -> {
mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startObject("my-sparse-vector")
.field("WrongDim123", 100f)
.endObject()
.endObject()),
XContentType.JSON));
mapper.parse(source(b -> {
b.startObject("field");
b.field("WrongDim123", 100f);
b.endObject();
}));
});
assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
assertThat(e.getCause().getMessage(), containsString(
@ -228,14 +190,11 @@ public class SparseVectorFieldMapperTests extends FieldMapperTestCase<SparseVect
// 4. test for an error on a wrong format for the map of dims to values
e = expectThrows(MapperParsingException.class, () -> {
mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder()
.startObject()
.startObject("my-sparse-vector")
.startArray(Integer.toString(10)).value(10f).value(100f).endArray()
.endObject()
.endObject()),
XContentType.JSON));
mapper.parse(source(b -> {
b.startObject("field");
b.startArray("10").value(10f).value(100f).endArray();
b.endObject();
}));
});
assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
assertThat(e.getCause().getMessage(), containsString(
@ -249,41 +208,29 @@ public class SparseVectorFieldMapperTests extends FieldMapperTestCase<SparseVect
.boxed()
.collect(Collectors.toMap(String::valueOf, Function.identity()));
BytesReference validDoc = BytesReference.bytes(
XContentFactory.jsonBuilder().startObject()
.field("my-sparse-vector", validVector)
.endObject());
mapper.parse(new SourceToParse("test-index", "_doc", "1", validDoc, XContentType.JSON));
DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
mapper.parse(source(b -> b.field("field", validVector)));
Map<String, Object> invalidVector = IntStream.range(0, SparseVectorFieldMapper.MAX_DIMS_COUNT + 1)
.boxed()
.collect(Collectors.toMap(String::valueOf, Function.identity()));
BytesReference invalidDoc = BytesReference.bytes(
XContentFactory.jsonBuilder().startObject()
.field("my-sparse-vector", invalidVector)
.endObject());
MapperParsingException e = expectThrows(MapperParsingException.class, () -> mapper.parse(
new SourceToParse("test-index", "_doc", "1", invalidDoc, XContentType.JSON)));
MapperParsingException e = expectThrows(MapperParsingException.class,
() -> mapper.parse(source(b -> b.field("field", invalidVector))));
assertThat(e.getDetailedMessage(), containsString("has exceeded the maximum allowed number of dimensions"));
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
@Override
protected SparseVectorFieldMapper.Builder newBuilder() {
return new SparseVectorFieldMapper.Builder("sparsevector");
public void testUpdates() throws IOException {
// no updates to test
}
@Override
public void testSerialization() throws IOException {
super.testSerialization();
assertWarnings("The [sparse_vector] field type is deprecated and will be removed in 8.0.");
}
@Override
public void testMergeConflicts() {
super.testMergeConflicts();
assertWarnings("The [sparse_vector] field type is deprecated and will be removed in 8.0.");
public void testMeta() throws IOException {
super.testMeta();
assertParseMinimalWarnings();
}
}