Fix an off-by-one error in the vector field dimension limit. (#40489)

Previously, only vectors with up to 499 dimensions were accepted, even though the documented limit is 500 dimensions.
This commit is contained in:
Julie Tibshirani 2019-03-27 11:13:51 -07:00
parent 64b31f44af
commit 419cf1c02f
4 changed files with 69 additions and 18 deletions

View File

@ -169,10 +169,9 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap
buf[offset+2] = (byte) (intValue >> 8); buf[offset+2] = (byte) (intValue >> 8);
buf[offset+3] = (byte) intValue; buf[offset+3] = (byte) intValue;
offset += INT_BYTES; offset += INT_BYTES;
dim++; if (dim++ >= MAX_DIMS_COUNT) {
if (dim >= MAX_DIMS_COUNT) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] has exceeded the maximum allowed number of dimensions of :[" + MAX_DIMS_COUNT + "]"); "] has exceeded the maximum allowed number of dimensions of [" + MAX_DIMS_COUNT + "]");
} }
} }
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset)); BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset));

View File

@ -178,10 +178,9 @@ public class SparseVectorFieldMapper extends FieldMapper {
} }
dims[dimCount] = dim; dims[dimCount] = dim;
values[dimCount] = value; values[dimCount] = value;
dimCount ++; if (dimCount++ >= MAX_DIMS_COUNT) {
if (dimCount >= MAX_DIMS_COUNT) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] has exceeded the maximum allowed number of dimensions of :[" + MAX_DIMS_COUNT + "]"); "] has exceeded the maximum allowed number of dimensions of [" + MAX_DIMS_COUNT + "]");
} }
} else { } else {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +

View File

@ -30,18 +30,19 @@ import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.IndexService; import org.elasticsearch.index.IndexService;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.ESSingleNodeTestCase;
import org.hamcrest.Matchers; import org.junit.Before;
import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
public class DenseVectorFieldMapperTests extends ESSingleNodeTestCase { public class DenseVectorFieldMapperTests extends ESSingleNodeTestCase {
private DocumentMapper mapper;
@Override @Before
protected Collection<Class<? extends Plugin>> getPlugins() { public void setUpMapper() throws Exception {
return pluginList(MapperExtrasPlugin.class);
}
public void testDefaults() throws Exception {
IndexService indexService = createIndex("test-index"); IndexService indexService = createIndex("test-index");
DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder() String mapping = Strings.toString(XContentFactory.jsonBuilder()
@ -53,10 +54,15 @@ public class DenseVectorFieldMapperTests extends ESSingleNodeTestCase {
.endObject() .endObject()
.endObject() .endObject()
.endObject()); .endObject());
mapper = parser.parse("_doc", new CompressedXContent(mapping));
}
DocumentMapper mapper = parser.parse("_doc", new CompressedXContent(mapping)); @Override
assertEquals(mapping, mapper.mappingSource().toString()); protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(MapperExtrasPlugin.class);
}
public void testDefaults() throws Exception {
float[] expectedArray = {-12.1f, 100.7f, -4}; float[] expectedArray = {-12.1f, 100.7f, -4};
ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
.bytes(XContentFactory.jsonBuilder() .bytes(XContentFactory.jsonBuilder()
@ -66,7 +72,7 @@ public class DenseVectorFieldMapperTests extends ESSingleNodeTestCase {
XContentType.JSON)); XContentType.JSON));
IndexableField[] fields = doc1.rootDoc().getFields("my-dense-vector"); IndexableField[] fields = doc1.rootDoc().getFields("my-dense-vector");
assertEquals(1, fields.length); assertEquals(1, fields.length);
assertThat(fields[0], Matchers.instanceOf(BinaryDocValuesField.class)); assertThat(fields[0], instanceOf(BinaryDocValuesField.class));
// assert that after decoding the indexed value is equal to expected // assert that after decoding the indexed value is equal to expected
BytesRef vectorBR = ((BinaryDocValuesField) fields[0]).binaryValue(); BytesRef vectorBR = ((BinaryDocValuesField) fields[0]).binaryValue();
@ -78,4 +84,22 @@ public class DenseVectorFieldMapperTests extends ESSingleNodeTestCase {
0.001f 0.001f
); );
} }
public void testDimensionLimit() throws IOException {
float[] validVector = new float[DenseVectorFieldMapper.MAX_DIMS_COUNT];
BytesReference validDoc = BytesReference.bytes(
XContentFactory.jsonBuilder().startObject()
.array("my-dense-vector", validVector)
.endObject());
mapper.parse(new SourceToParse("test-index", "_doc", "1", validDoc, XContentType.JSON));
float[] invalidVector = new float[DenseVectorFieldMapper.MAX_DIMS_COUNT + 1];
BytesReference invalidDoc = BytesReference.bytes(
XContentFactory.jsonBuilder().startObject()
.array("my-dense-vector", invalidVector)
.endObject());
MapperParsingException e = expectThrows(MapperParsingException.class, () -> mapper.parse(
new SourceToParse("test-index", "_doc", "1", invalidDoc, XContentType.JSON)));
assertThat(e.getDetailedMessage(), containsString("has exceeded the maximum allowed number of dimensions"));
}
} }

View File

@ -33,7 +33,12 @@ import org.elasticsearch.test.ESSingleNodeTestCase;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.Before; import org.junit.Before;
import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.hamcrest.core.IsInstanceOf.instanceOf;
@ -42,7 +47,7 @@ public class SparseVectorFieldMapperTests extends ESSingleNodeTestCase {
private DocumentMapper mapper; private DocumentMapper mapper;
@Before @Before
public void setup() throws Exception { public void setUpMapper() throws Exception {
IndexService indexService = createIndex("test-index"); IndexService indexService = createIndex("test-index");
DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); DocumentMapperParser parser = indexService.mapperService().documentMapperParser();
String mapping = Strings.toString(XContentFactory.jsonBuilder() String mapping = Strings.toString(XContentFactory.jsonBuilder()
@ -100,7 +105,7 @@ public class SparseVectorFieldMapperTests extends ESSingleNodeTestCase {
); );
} }
public void testErrors() { public void testDimensionNumberValidation() {
// 1. test for an error on negative dimension // 1. test for an error on negative dimension
MapperParsingException e = expectThrows(MapperParsingException.class, () -> { MapperParsingException e = expectThrows(MapperParsingException.class, () -> {
mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference
@ -161,4 +166,28 @@ public class SparseVectorFieldMapperTests extends ESSingleNodeTestCase {
assertThat(e.getCause().getMessage(), containsString( assertThat(e.getCause().getMessage(), containsString(
"takes an object that maps a dimension number to a float, but got unexpected token [START_ARRAY]")); "takes an object that maps a dimension number to a float, but got unexpected token [START_ARRAY]"));
} }
public void testDimensionLimit() throws IOException {
Map<String, Object> validVector = IntStream.range(0, SparseVectorFieldMapper.MAX_DIMS_COUNT)
.boxed()
.collect(Collectors.toMap(String::valueOf, Function.identity()));
BytesReference validDoc = BytesReference.bytes(
XContentFactory.jsonBuilder().startObject()
.field("my-sparse-vector", validVector)
.endObject());
mapper.parse(new SourceToParse("test-index", "_doc", "1", validDoc, XContentType.JSON));
Map<String, Object> invalidVector = IntStream.range(0, SparseVectorFieldMapper.MAX_DIMS_COUNT + 1)
.boxed()
.collect(Collectors.toMap(String::valueOf, Function.identity()));
BytesReference invalidDoc = BytesReference.bytes(
XContentFactory.jsonBuilder().startObject()
.field("my-sparse-vector", invalidVector)
.endObject());
MapperParsingException e = expectThrows(MapperParsingException.class, () -> mapper.parse(
new SourceToParse("test-index", "_doc", "1", invalidDoc, XContentType.JSON)));
assertThat(e.getDetailedMessage(), containsString("has exceeded the maximum allowed number of dimensions"));
}
} }