Add a MemorySegment Vector scorer - for scoring without copying on-heap (#13339)

Add a MemorySegment Vector scorer - for scoring without copying on-heap.

The vector scorer loads values directly from the backing memory segment when available. Otherwise, if the vector data spans across segments the scorer copies the vector data on-heap.

A benchmark shows ~2x performance improvement of this scorer over the default copy-on-heap scorer.

The scorer currently only operates on vectors with an element size of byte. We can evaluate if and how to support floats separately.
This commit is contained in:
Chris Hegarty 2024-05-21 17:34:37 +01:00 committed by GitHub
parent f70999980c
commit 05f04aa08a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 1350 additions and 71 deletions

View File

@ -348,6 +348,8 @@ Optimizations
* GITHUB#13392: Replace Map<Long, Object> by primitive LongObjectHashMap. (Bruno Roustant)
* GITHUB#13339: Add a MemorySegment Vector scorer - for scoring without copying on-heap (Chris Hegarty)
Bug Fixes
---------------------

View File

@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import java.io.IOException;
import java.nio.file.Files;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.openjdk.jmh.annotations.*;
/**
 * Benchmark comparing byte-vector similarity scoring with the default (copy on-heap) scorer vs.
 * the memory-segment scorer (enabled in the second fork by adding jdk.incubator.vector).
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
// first iteration is complete garbage, so make sure we really warmup
@Warmup(iterations = 4, time = 1)
// real iterations. not useful to spend tons of time here, better to fork more
@Measurement(iterations = 5, time = 1)
// engage some noise reduction
@Fork(
    value = 3,
    jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorScorerBenchmark {

  @Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
  int size; // vector length in bytes (== dimension, for byte element vectors)

  Directory dir;
  IndexInput in;
  RandomAccessVectorValues vectorValues;
  byte[] vec1, vec2;
  RandomVectorScorer scorer;

  /** Writes two random {@code size}-byte vectors to a fresh directory and opens a scorer. */
  @Setup(Level.Iteration)
  public void init() throws IOException {
    vec1 = new byte[size];
    vec2 = new byte[size];
    ThreadLocalRandom.current().nextBytes(vec1);
    ThreadLocalRandom.current().nextBytes(vec2);

    dir = new MMapDirectory(Files.createTempDirectory("VectorScorerBenchmark"));
    try (IndexOutput out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
      out.writeBytes(vec1, 0, vec1.length);
      out.writeBytes(vec2, 0, vec2.length);
    }
    in = dir.openInput("vector.data", IOContext.DEFAULT);
    vectorValues = vectorValues(size, 2, in, DOT_PRODUCT);
    scorer =
        FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
            .getRandomVectorScorerSupplier(DOT_PRODUCT, vectorValues)
            .scorer(0);
  }

  /**
   * Tear down per-iteration to match the per-iteration setup: init() overwrites {@link #dir} and
   * {@link #in} on every iteration, so a trial-level teardown would leak every directory/input
   * instance except the last one.
   */
  @TearDown(Level.Iteration)
  public void teardown() throws IOException {
    // close the input before the directory it was opened from
    IOUtils.close(in, dir);
  }

  @Benchmark
  public float binaryDotProductDefault() throws IOException {
    return scorer.score(1);
  }

  @Benchmark
  @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
  public float binaryDotProductMemSeg() throws IOException {
    return scorer.score(1);
  }

  /** Creates dense off-heap byte vector values over a slice of the given input. */
  static RandomAccessVectorValues vectorValues(
      int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
    return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
        dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim);
  }

  /** A FlatVectorsScorer whose every method throws, ensuring it is never used for scoring. */
  static final class ThrowingFlatVectorScorer implements FlatVectorsScorer {

    @Override
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
        VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) {
      throw new UnsupportedOperationException();
    }

    @Override
    public RandomVectorScorer getRandomVectorScorer(
        VectorSimilarityFunction similarityFunction,
        RandomAccessVectorValues vectorValues,
        float[] target) {
      throw new UnsupportedOperationException();
    }

    @Override
    public RandomVectorScorer getRandomVectorScorer(
        VectorSimilarityFunction similarityFunction,
        RandomAccessVectorValues vectorValues,
        byte[] target) {
      throw new UnsupportedOperationException();
    }
  }
}

View File

@ -29,6 +29,9 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
* @lucene.experimental
*/
public class DefaultFlatVectorScorer implements FlatVectorsScorer {
public static final DefaultFlatVectorScorer INSTANCE = new DefaultFlatVectorScorer();
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)

View File

@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.hnsw;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
/**
 * Static entry point for obtaining {@link FlatVectorsScorer} implementations.
 *
 * @lucene.experimental
 */
public final class FlatVectorScorerUtil {

  private static final VectorizationProvider PROVIDER = VectorizationProvider.getInstance();

  // static utility only
  private FlatVectorScorerUtil() {}

  /**
   * Returns a FlatVectorsScorer that supports the Lucene99 format. Scorers retrieved through this
   * method may be optimized on certain platforms. Otherwise, a DefaultFlatVectorScorer is returned.
   */
  public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
    return PROVIDER.getLucene99FlatVectorsScorer();
  }
}

View File

@ -22,7 +22,7 @@ import java.util.concurrent.ExecutorService;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.MergePolicy;
@ -139,7 +139,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
/** The format for storing, reading, merging vectors on disk */
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

View File

@ -19,6 +19,7 @@ package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
@ -48,7 +49,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "veq";
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
/** The minimum confidence interval */
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
@ -101,7 +102,8 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}
public static float calculateDefaultConfidenceInterval(int vectorDimension) {

View File

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.internal.tests;
import org.apache.lucene.store.FilterIndexInput;
/**
 * Access to {@link org.apache.lucene.store.FilterIndexInput} internals exposed to the test
 * framework.
 *
 * @lucene.internal
 */
@FunctionalInterface // single abstract method; bound as a method reference by FilterIndexInput
public interface FilterIndexInputAccess {

  /**
   * Adds the given test FilterIndexInput class, registering it as a test-only filter type so that
   * unwrapping of test wrappers can recognize instances of it.
   *
   * @param cls the FilterIndexInput subclass used only by tests
   */
  void addTestFilterType(Class<? extends FilterIndexInput> cls);
}

View File

@ -23,6 +23,7 @@ import org.apache.lucene.index.ConcurrentMergeScheduler;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.FilterIndexInput;
/**
* A set of static methods returning accessors for internal, package-private functionality in
@ -48,12 +49,14 @@ public final class TestSecrets {
ensureInitialized.accept(ConcurrentMergeScheduler.class);
ensureInitialized.accept(SegmentReader.class);
ensureInitialized.accept(IndexWriter.class);
ensureInitialized.accept(FilterIndexInput.class);
}
private static IndexPackageAccess indexPackageAccess;
private static ConcurrentMergeSchedulerAccess cmsAccess;
private static SegmentReaderAccess segmentReaderAccess;
private static IndexWriterAccess indexWriterAccess;
private static FilterIndexInputAccess filterIndexInputAccess;
private TestSecrets() {}
@ -81,6 +84,12 @@ public final class TestSecrets {
return Objects.requireNonNull(indexWriterAccess);
}
/** Return the accessor to internal secrets for a {@link FilterIndexInput}. */
public static FilterIndexInputAccess getFilterInputIndexAccess() {
  ensureCaller(); // restrict access to the allow-listed test-framework callers
  return Objects.requireNonNull(filterIndexInputAccess); // fails if not yet initialized
}
/** For internal initialization only. */
public static void setIndexWriterAccess(IndexWriterAccess indexWriterAccess) {
ensureNull(TestSecrets.indexWriterAccess);
@ -105,6 +114,12 @@ public final class TestSecrets {
TestSecrets.segmentReaderAccess = segmentReaderAccess;
}
/** For internal initialization only. */
public static void setFilterInputIndexAccess(FilterIndexInputAccess filterIndexInputAccess) {
  // may be set at most once; a second attempt trips ensureNull
  ensureNull(TestSecrets.filterIndexInputAccess);
  TestSecrets.filterIndexInputAccess = filterIndexInputAccess;
}
private static void ensureNull(Object ob) {
if (ob != null) {
throw new AssertionError(

View File

@ -17,6 +17,9 @@
package org.apache.lucene.internal.vectorization;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
/** Default provider returning scalar implementations. */
final class DefaultVectorizationProvider extends VectorizationProvider {
@ -30,4 +33,9 @@ final class DefaultVectorizationProvider extends VectorizationProvider {
public VectorUtilSupport getVectorUtilSupport() {
return vectorUtilSupport;
}
@Override
public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return DefaultFlatVectorScorer.INSTANCE;
}
}

View File

@ -27,6 +27,7 @@ import java.util.Set;
import java.util.function.Predicate;
import java.util.logging.Logger;
import java.util.stream.Stream;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.VectorUtil;
@ -91,6 +92,9 @@ public abstract class VectorizationProvider {
*/
public abstract VectorUtilSupport getVectorUtilSupport();
/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();
// *** Lookup mechanism: ***
private static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName());
@ -177,7 +181,10 @@ public abstract class VectorizationProvider {
}
// add all possible callers here as FQCN:
private static final Set<String> VALID_CALLERS = Set.of("org.apache.lucene.util.VectorUtil");
private static final Set<String> VALID_CALLERS =
Set.of(
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
"org.apache.lucene.util.VectorUtil");
private static void ensureCaller() {
final boolean validCaller =

View File

@ -17,6 +17,8 @@
package org.apache.lucene.store;
import java.io.IOException;
import java.util.concurrent.CopyOnWriteArrayList;
import org.apache.lucene.internal.tests.TestSecrets;
/**
* IndexInput implementation that delegates calls to another directory. This class can be used to
@ -29,6 +31,12 @@ import java.io.IOException;
*/
public class FilterIndexInput extends IndexInput {
static final CopyOnWriteArrayList<Class<?>> TEST_FILTER_INPUTS = new CopyOnWriteArrayList<>();
static {
TestSecrets.setFilterInputIndexAccess(TEST_FILTER_INPUTS::add);
}
/**
* Unwraps all FilterIndexInputs until the first non-FilterIndexInput IndexInput instance and
* returns it
@ -40,6 +48,17 @@ public class FilterIndexInput extends IndexInput {
return in;
}
/**
 * Unwraps all test FilterIndexInputs until the first non-test FilterIndexInput IndexInput
 * instance and returns it. Only filter classes registered with {@code TEST_FILTER_INPUTS} are
 * unwrapped; any other (production) FilterIndexInput wrapper is left intact.
 */
public static IndexInput unwrapOnlyTest(IndexInput in) {
  while (in instanceof FilterIndexInput && TEST_FILTER_INPUTS.contains(in.getClass())) {
    in = ((FilterIndexInput) in).in; // peel one test wrapper and continue
  }
  return in;
}
protected final IndexInput in;
/** Creates a FilterIndexInput with a resource description and wrapped delegate IndexInput */

View File

@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.internal.vectorization;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
 * A {@link RandomVectorScorer} over vectors whose element size is byte. Vector data is read
 * directly from the backing {@link MemorySegment} when the vector lies within a single segment;
 * otherwise the bytes are first copied into an on-heap scratch buffer.
 */
abstract sealed class Lucene99MemorySegmentByteVectorScorer
    extends RandomVectorScorer.AbstractRandomVectorScorer {

  // length of a single vector, in bytes
  final int vectorByteSize;
  final MemorySegmentAccessInput input;
  // heap segment wrapping the caller-supplied query vector
  final MemorySegment query;
  // lazily allocated copy buffer; used only when a vector spans backing segments
  byte[] scratch;

  /**
   * Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is
   * returned.
   */
  public static Optional<Lucene99MemorySegmentByteVectorScorer> create(
      VectorSimilarityFunction type,
      IndexInput input,
      RandomAccessVectorValues values,
      byte[] queryVector) {
    // unwrap test-only filters so a wrapped MemorySegmentAccessInput is still recognized
    input = FilterIndexInput.unwrapOnlyTest(input);
    if (!(input instanceof MemorySegmentAccessInput msInput)) {
      return Optional.empty(); // not memory-segment backed; caller must fall back
    }
    checkInvariants(values.size(), values.getVectorByteLength(), input);
    return switch (type) {
      case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector));
      case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector));
      case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector));
      case MAXIMUM_INNER_PRODUCT -> Optional.of(
          new MaxInnerProductScorer(msInput, values, queryVector));
    };
  }

  Lucene99MemorySegmentByteVectorScorer(
      MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] queryVector) {
    super(values);
    this.input = input;
    this.vectorByteSize = values.getVectorByteLength();
    this.query = MemorySegment.ofArray(queryVector);
  }

  /**
   * Returns a segment covering the vector data of the given ordinal. Slices the backing segment
   * directly when possible; otherwise copies the bytes into {@link #scratch} and wraps that.
   */
  final MemorySegment getSegment(int ord) throws IOException {
    checkOrdinal(ord);
    long byteOffset = (long) ord * vectorByteSize;
    MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
    if (seg == null) {
      // vector data spans segments: copy on-heap
      if (scratch == null) {
        scratch = new byte[vectorByteSize];
      }
      input.readBytes(byteOffset, scratch, 0, vectorByteSize);
      seg = MemorySegment.ofArray(scratch);
    }
    return seg;
  }

  // Fails fast if the input cannot possibly hold maxOrd vectors of the given byte length.
  static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
    if (input.length() < (long) vectorByteLength * maxOrd) {
      throw new IllegalArgumentException("input length is less than expected vector data");
    }
  }

  // Rejects ordinals outside [0, maxOrd).
  final void checkOrdinal(int ord) {
    if (ord < 0 || ord >= maxOrd()) {
      throw new IllegalArgumentException("illegal ordinal: " + ord);
    }
  }

  static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer {
    CosineScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
      super(input, values, query);
    }

    @Override
    public float score(int node) throws IOException {
      checkOrdinal(node);
      float raw = PanamaVectorUtilSupport.cosine(query, getSegment(node));
      return (1 + raw) / 2; // map raw cosine from [-1, 1] into [0, 1]
    }
  }

  static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScorer {
    DotProductScorer(
        MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
      super(input, values, query);
    }

    @Override
    public float score(int node) throws IOException {
      checkOrdinal(node);
      // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
      float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
      return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
    }
  }

  static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer {
    EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
      super(input, values, query);
    }

    @Override
    public float score(int node) throws IOException {
      checkOrdinal(node);
      float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
      return 1 / (1f + raw); // smaller squared distance maps to a score closer to 1
    }
  }

  static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVectorScorer {
    MaxInnerProductScorer(
        MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
      super(input, values, query);
    }

    @Override
    public float score(int node) throws IOException {
      checkOrdinal(node);
      float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
      if (raw < 0) {
        // compress negative products into (0, 1) so the score stays positive
        return 1 / (1 + -1 * raw);
      }
      return raw + 1; // shift non-negative products above 1
    }
  }
}

View File

@ -0,0 +1,210 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.internal.vectorization;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
 * A score supplier of vectors whose element size is byte. Vector data is read directly from the
 * backing {@link MemorySegment} when a vector lies within a single segment; otherwise the bytes
 * are copied into an on-heap scratch buffer first.
 */
public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
    implements RandomVectorScorerSupplier {

  // length of a single vector, in bytes
  final int vectorByteSize;
  // number of vectors held by the input
  final int maxOrd;
  final MemorySegmentAccessInput input;
  final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds
  // lazily allocated copy buffers, one per side of the comparison, used only when a
  // vector spans backing segments
  byte[] scratch1, scratch2;

  /**
   * Return an optional whose value, if present, is the scorer supplier. Otherwise, an empty
   * optional is returned.
   */
  static Optional<RandomVectorScorerSupplier> create(
      VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) {
    // unwrap test-only filters so a wrapped MemorySegmentAccessInput is still recognized
    input = FilterIndexInput.unwrapOnlyTest(input);
    if (!(input instanceof MemorySegmentAccessInput msInput)) {
      return Optional.empty(); // not memory-segment backed; caller must fall back
    }
    checkInvariants(values.size(), values.getVectorByteLength(), input);
    return switch (type) {
      case COSINE -> Optional.of(new CosineSupplier(msInput, values));
      case DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values));
      case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values));
      case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values));
    };
  }

  Lucene99MemorySegmentByteVectorScorerSupplier(
      MemorySegmentAccessInput input, RandomAccessVectorValues values) {
    this.input = input;
    this.values = values;
    this.vectorByteSize = values.getVectorByteLength();
    this.maxOrd = values.size();
  }

  // Fails fast if the input cannot possibly hold maxOrd vectors of the given byte length.
  static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
    if (input.length() < (long) vectorByteLength * maxOrd) {
      throw new IllegalArgumentException("input length is less than expected vector data");
    }
  }

  // Rejects ordinals outside [0, maxOrd).
  final void checkOrdinal(int ord) {
    if (ord < 0 || ord >= maxOrd) {
      throw new IllegalArgumentException("illegal ordinal: " + ord);
    }
  }

  /**
   * Returns a segment covering the vector data of the given ordinal, for the first side of the
   * comparison (the ordinal passed to {@code scorer(ord)}). Slices the backing segment directly
   * when possible; otherwise copies into {@link #scratch1}.
   */
  final MemorySegment getFirstSegment(int ord) throws IOException {
    long byteOffset = (long) ord * vectorByteSize;
    MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
    if (seg == null) {
      // vector data spans segments: copy on-heap
      if (scratch1 == null) {
        scratch1 = new byte[vectorByteSize];
      }
      input.readBytes(byteOffset, scratch1, 0, vectorByteSize);
      seg = MemorySegment.ofArray(scratch1);
    }
    return seg;
  }

  /**
   * Same as getFirstSegment, but uses the distinct {@link #scratch2} buffer so that both sides of
   * a comparison can be copied on-heap at the same time without clobbering each other.
   */
  final MemorySegment getSecondSegment(int ord) throws IOException {
    long byteOffset = (long) ord * vectorByteSize;
    MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
    if (seg == null) {
      if (scratch2 == null) {
        scratch2 = new byte[vectorByteSize];
      }
      input.readBytes(byteOffset, scratch2, 0, vectorByteSize);
      seg = MemorySegment.ofArray(scratch2);
    }
    return seg;
  }

  static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
    CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
      super(input, values);
    }

    @Override
    public RandomVectorScorer scorer(int ord) {
      checkOrdinal(ord);
      return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
        @Override
        public float score(int node) throws IOException {
          checkOrdinal(node);
          float raw = PanamaVectorUtilSupport.cosine(getFirstSegment(ord), getSecondSegment(node));
          return (1 + raw) / 2; // map raw cosine from [-1, 1] into [0, 1]
        }
      };
    }

    @Override
    public CosineSupplier copy() throws IOException {
      // clone the input so the copy has independent read state; values is shared
      return new CosineSupplier(input.clone(), values);
    }
  }

  static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
    DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
      super(input, values);
    }

    @Override
    public RandomVectorScorer scorer(int ord) {
      checkOrdinal(ord);
      return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
        @Override
        public float score(int node) throws IOException {
          checkOrdinal(node);
          // divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
          float raw =
              PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node));
          return 0.5f + raw / (float) (values.dimension() * (1 << 15));
        }
      };
    }

    @Override
    public DotProductSupplier copy() throws IOException {
      // clone the input so the copy has independent read state; values is shared
      return new DotProductSupplier(input.clone(), values);
    }
  }

  static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
    EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
      super(input, values);
    }

    @Override
    public RandomVectorScorer scorer(int ord) {
      checkOrdinal(ord);
      return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
        @Override
        public float score(int node) throws IOException {
          checkOrdinal(node);
          float raw =
              PanamaVectorUtilSupport.squareDistance(getFirstSegment(ord), getSecondSegment(node));
          return 1 / (1f + raw); // smaller squared distance maps to a score closer to 1
        }
      };
    }

    @Override
    public EuclideanSupplier copy() throws IOException {
      // clone the input so the copy has independent read state; values is shared
      return new EuclideanSupplier(input.clone(), values);
    }
  }

  static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
    MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
      super(input, values);
    }

    @Override
    public RandomVectorScorer scorer(int ord) {
      checkOrdinal(ord);
      return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
        @Override
        public float score(int node) throws IOException {
          checkOrdinal(node);
          float raw =
              PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node));
          if (raw < 0) {
            // compress negative products into (0, 1) so the score stays positive
            return 1 / (1 + -1 * raw);
          }
          return raw + 1; // shift non-negative products above 1
        }
      };
    }

    @Override
    public MaxInnerProductSupplier copy() throws IOException {
      // clone the input so the copy has independent read state; values is shared
      return new MaxInnerProductSupplier(input.clone(), values);
    }
  }
}

View File

@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.internal.vectorization;
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/**
 * A {@link FlatVectorsScorer} that produces memory-segment backed scorers where possible and
 * falls back to a delegate scorer otherwise.
 */
public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer {

  public static final Lucene99MemorySegmentFlatVectorsScorer INSTANCE =
      new Lucene99MemorySegmentFlatVectorsScorer(DefaultFlatVectorScorer.INSTANCE);

  /** Fallback used whenever the memory-segment path is not applicable. */
  private final FlatVectorsScorer delegate;

  private Lucene99MemorySegmentFlatVectorsScorer(FlatVectorsScorer delegate) {
    this.delegate = delegate;
  }

  @Override
  public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
      VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues)
      throws IOException {
    // the memory-segment path currently handles byte (binary) vectors only
    RandomVectorScorerSupplier offHeap = null;
    if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
      offHeap =
          Lucene99MemorySegmentByteVectorScorerSupplier.create(
                  similarityType, vectorValues.getSlice(), vectorValues)
              .orElse(null);
    }
    return offHeap != null
        ? offHeap
        : delegate.getRandomVectorScorerSupplier(similarityType, vectorValues);
  }

  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityType,
      RandomAccessVectorValues vectorValues,
      float[] target)
      throws IOException {
    // float vectors are not handled by the memory-segment path, so always delegate
    return delegate.getRandomVectorScorer(similarityType, vectorValues, target);
  }

  @Override
  public RandomVectorScorer getRandomVectorScorer(
      VectorSimilarityFunction similarityType,
      RandomAccessVectorValues vectorValues,
      byte[] queryVector)
      throws IOException {
    checkDimensions(queryVector.length, vectorValues.dimension());
    RandomVectorScorer offHeap = null;
    if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
      offHeap =
          Lucene99MemorySegmentByteVectorScorer.create(
                  similarityType, vectorValues.getSlice(), vectorValues, queryVector)
              .orElse(null);
    }
    return offHeap != null
        ? offHeap
        : delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector);
  }

  /** Rejects a query whose dimension does not match the field's. */
  static void checkDimensions(int queryLen, int fieldLen) {
    if (queryLen == fieldLen) {
      return;
    }
    throw new IllegalArgumentException(
        "vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
  }

  @Override
  public String toString() {
    return "Lucene99MemorySegmentFlatVectorsScorer()";
  }
}

View File

@ -16,6 +16,8 @@
*/
package org.apache.lucene.internal.vectorization;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.nio.ByteOrder.LITTLE_ENDIAN;
import static jdk.incubator.vector.VectorOperators.ADD;
import static jdk.incubator.vector.VectorOperators.B2I;
import static jdk.incubator.vector.VectorOperators.B2S;
@ -23,6 +25,7 @@ import static jdk.incubator.vector.VectorOperators.LSHR;
import static jdk.incubator.vector.VectorOperators.S2I;
import static jdk.incubator.vector.VectorOperators.ZERO_EXTEND_B2S;
import java.lang.foreign.MemorySegment;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
@ -307,39 +310,44 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
@Override
public int dotProduct(byte[] a, byte[] b) {
  // wrap the arrays as heap segments and reuse the memory-segment implementation
  return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
}
public static int dotProduct(MemorySegment a, MemorySegment b) {
assert a.byteSize() == b.byteSize();
int i = 0;
int res = 0;
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) {
// compute vectorized dot product consistent with VPDPBUSD instruction
if (VECTOR_BITSIZE >= 512) {
i += BYTE_SPECIES.loopBound(a.length);
i += BYTE_SPECIES.loopBound(a.byteSize());
res += dotProductBody512(a, b, i);
} else if (VECTOR_BITSIZE == 256) {
i += BYTE_SPECIES.loopBound(a.length);
i += BYTE_SPECIES.loopBound(a.byteSize());
res += dotProductBody256(a, b, i);
} else {
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
res += dotProductBody128(a, b, i);
}
}
// scalar tail
for (; i < a.length; i++) {
res += b[i] * a[i];
for (; i < a.byteSize(); i++) {
res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i);
}
return res;
}
/** vectorized dot product body (512 bit vectors) */
private int dotProductBody512(byte[] a, byte[] b, int limit) {
private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
IntVector acc = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@ -355,11 +363,11 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
/** vectorized dot product body (256 bit vectors) */
private int dotProductBody256(byte[] a, byte[] b, int limit) {
private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_256);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
// 32-bit multiply and add into accumulator
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@ -371,13 +379,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
/** vectorized dot product body (128 bit vectors) */
private int dotProductBody128(byte[] a, byte[] b, int limit) {
private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
// 4 bytes at a time (re-loading half the vector each time!)
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
// load 8 bytes
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
// process first "half" only: 16-bit multiply
Vector<Short> va16 = va8.convert(B2S, 0);
@ -569,6 +577,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
@Override
public float cosine(byte[] a, byte[] b) {
return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
}
public static float cosine(MemorySegment a, MemorySegment b) {
int i = 0;
int sum = 0;
int norm1 = 0;
@ -576,17 +588,17 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) {
final float[] ret;
if (VECTOR_BITSIZE >= 512) {
i += BYTE_SPECIES.loopBound(a.length);
i += BYTE_SPECIES.loopBound((int) a.byteSize());
ret = cosineBody512(a, b, i);
} else if (VECTOR_BITSIZE == 256) {
i += BYTE_SPECIES.loopBound(a.length);
i += BYTE_SPECIES.loopBound((int) a.byteSize());
ret = cosineBody256(a, b, i);
} else {
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
ret = cosineBody128(a, b, i);
}
sum += ret[0];
@ -595,9 +607,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
// scalar tail
for (; i < a.length; i++) {
byte elem1 = a[i];
byte elem2 = b[i];
for (; i < a.byteSize(); i++) {
byte elem1 = a.get(JAVA_BYTE, i);
byte elem2 = b.get(JAVA_BYTE, i);
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
@ -606,13 +618,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
/** vectorized cosine body (512 bit vectors) */
private float[] cosineBody512(byte[] a, byte[] b, int limit) {
private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
IntVector accSum = IntVector.zero(INT_SPECIES);
IntVector accNorm1 = IntVector.zero(INT_SPECIES);
IntVector accNorm2 = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@ -636,13 +648,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
/** vectorized cosine body (256 bit vectors) */
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
// 16-bit multiply, and add into accumulators
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@ -661,13 +673,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
/** vectorized cosine body (128 bit vectors) */
private float[] cosineBody128(byte[] a, byte[] b, int limit) {
private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
// process first half only: 16-bit multiply
Vector<Short> va16 = va8.convert(B2S, 0);
@ -689,35 +701,40 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
@Override
public int squareDistance(byte[] a, byte[] b) {
return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
}
public static int squareDistance(MemorySegment a, MemorySegment b) {
assert a.byteSize() == b.byteSize();
int i = 0;
int res = 0;
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) {
if (VECTOR_BITSIZE >= 256) {
i += BYTE_SPECIES.loopBound(a.length);
i += BYTE_SPECIES.loopBound((int) a.byteSize());
res += squareDistanceBody256(a, b, i);
} else {
i += ByteVector.SPECIES_64.loopBound(a.length);
i += ByteVector.SPECIES_64.loopBound((int) a.byteSize());
res += squareDistanceBody128(a, b, i);
}
}
// scalar tail
for (; i < a.length; i++) {
int diff = a[i] - b[i];
for (; i < a.byteSize(); i++) {
int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
res += diff * diff;
}
return res;
}
/** vectorized square distance body (256+ bit vectors) */
private int squareDistanceBody256(byte[] a, byte[] b, int limit) {
private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
IntVector acc = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
// 32-bit sub, multiply, and add into accumulators
// TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
@ -731,14 +748,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
}
/** vectorized square distance body (128 bit vectors) */
private int squareDistanceBody128(byte[] a, byte[] b, int limit) {
private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
// 128-bit implementation, which must "split up" vectors due to widening conversions
// it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
// 16-bit sub
Vector<Short> va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0);

View File

@ -21,6 +21,7 @@ import java.security.PrivilegedAction;
import java.util.Locale;
import java.util.logging.Logger;
import jdk.incubator.vector.FloatVector;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;
@ -73,4 +74,9 @@ final class PanamaVectorizationProvider extends VectorizationProvider {
public VectorUtilSupport getVectorUtilSupport() {
return vectorUtilSupport;
}
@Override
public FlatVectorsScorer getLucene99FlatVectorsScorer() {
return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE;
}
}

View File

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.store;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
/**
* Provides access to the backing memory segment.
*
* <p>Expert API, allows access to the backing memory.
*/
public interface MemorySegmentAccessInput extends RandomAccessInput, Cloneable {
/** Returns the memory segment for a given position and length, or null. */
MemorySegment segmentSliceOrNull(long pos, int len) throws IOException;
MemorySegmentAccessInput clone();
}

View File

@ -36,7 +36,8 @@ import org.apache.lucene.util.GroupVIntUtil;
* chunkSizePower</code>).
*/
@SuppressWarnings("preview")
abstract class MemorySegmentIndexInput extends IndexInput implements RandomAccessInput {
abstract class MemorySegmentIndexInput extends IndexInput
implements RandomAccessInput, MemorySegmentAccessInput {
static final ValueLayout.OfByte LAYOUT_BYTE = ValueLayout.JAVA_BYTE;
static final ValueLayout.OfShort LAYOUT_LE_SHORT =
ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
@ -562,6 +563,10 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
}
}
static boolean checkIndex(long index, long length) {
return index >= 0 && index < length;
}
@Override
public final void close() throws IOException {
if (curSegment == null) {
@ -673,6 +678,16 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
throw alreadyClosed(e);
}
}
@Override
public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException {
try {
Objects.checkIndex(pos + len, this.length + 1);
return curSegment.asSlice(pos, len);
} catch (IndexOutOfBoundsException e) {
throw handlePositionalIOOBE(e, "segmentSliceOrNull", pos);
}
}
}
/** This class adds offset support to MemorySegmentIndexInput, which is needed for slices. */
@ -738,6 +753,20 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
return super.readLong(pos + offset);
}
public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException {
if (pos + len > length) {
throw handlePositionalIOOBE(null, "segmentSliceOrNull", pos);
}
pos = pos + offset;
final int si = (int) (pos >> chunkSizePower);
final MemorySegment seg = segments[si];
final long segOffset = pos & chunkSizeMask;
if (checkIndex(segOffset + len, seg.byteSize() + 1)) {
return seg.asSlice(segOffset, len);
}
return null;
}
@Override
MemorySegmentIndexInput buildSlice(String sliceDescription, long ofs, long length) {
return super.buildSlice(sliceDescription, this.offset + ofs, length);

View File

@ -21,6 +21,8 @@ import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import java.io.ByteArrayOutputStream;
@ -60,8 +62,9 @@ public class TestFlatVectorScorer extends LuceneTestCase {
public static Iterable<Object[]> parametersFactory() {
var scorers =
List.of(
new DefaultFlatVectorScorer(),
new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()));
DefaultFlatVectorScorer.INSTANCE,
new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()),
FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
var dirs =
List.<ThrowingSupplier<Directory>>of(
TestFlatVectorScorer::newDirectory,
@ -76,7 +79,14 @@ public class TestFlatVectorScorer extends LuceneTestCase {
return objs;
}
// Tests that the creation of another scorer does not perturb previous scorers
public void testDefaultOrMemSegScorer() {
var scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
assertThat(
scorer.toString(),
is(oneOf("DefaultFlatVectorScorer()", "Lucene99MemorySegmentFlatVectorsScorer()")));
}
// Tests that the creation of another scorer does not disturb previous scorers
public void testMultipleByteScorers() throws IOException {
byte[] vec0 = new byte[] {0, 0, 0, 0};
byte[] vec1 = new byte[] {1, 1, 1, 1};

View File

@ -16,11 +16,15 @@
*/
package org.apache.lucene.codecs.lucene99;
import static java.lang.String.format;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
@ -219,9 +223,12 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
10, 20, 1, (byte) 4, false, 0.9f, null);
}
};
String expectedString =
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
String expectedPattern =
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))";
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
var memSegScorer =
format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}
public void testLimits() {

View File

@ -16,6 +16,11 @@
*/
package org.apache.lucene.codecs.lucene99;
import static java.lang.String.format;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;
import java.util.Locale;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
@ -37,9 +42,12 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
return new Lucene99HnswVectorsFormat(10, 20);
}
};
String expectedString =
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer()))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
String expectedPattern =
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))";
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
var memSegScorer =
format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}
public void testLimits() {

View File

@ -0,0 +1,398 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.internal.vectorization;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.junit.BeforeClass;
public class TestVectorScorer extends LuceneTestCase {
// Tolerance for float score comparisons between the two scorer implementations.
private static final double DELTA = 1e-5;
// Baseline scorer that always copies vector data on-heap.
static final FlatVectorsScorer DEFAULT_SCORER = DefaultFlatVectorScorer.INSTANCE;
// Scorer under test; resolved via the vectorization provider (requires Panama support).
static final FlatVectorsScorer MEMSEG_SCORER =
    VectorizationProvider.lookup(true).getLucene99FlatVectorsScorer();

@BeforeClass
public static void beforeClass() throws Exception {
  // Skip the whole suite when the provider falls back to the default scorer
  // (e.g. unsupported JDK), since there would be nothing distinct to compare.
  assumeTrue(
      "Test only works when the Memory segment scorer is present.",
      MEMSEG_SCORER.getClass() != DEFAULT_SCORER.getClass());
}
public void testSimpleScorer() throws IOException {
  // Exercise the scorers with the default (large) mmap chunk size.
  final long chunkSize = MMapDirectory.DEFAULT_MAX_CHUNK_SIZE;
  testSimpleScorer(chunkSize);
}
public void testSimpleScorerSmallChunkSize() throws IOException {
  // A tiny chunk size forces vectors to span chunks, exercising the on-heap copy path.
  testSimpleScorer(random().nextLong(4, 16));
}
public void testSimpleScorerMedChunkSize() throws IOException {
  // A chunk size where in some vectors will be copied on-heap, while others remain off-heap.
  final long medChunkSize = 64;
  testSimpleScorer(medChunkSize);
}
// Verifies that the memory-segment scorer produces the same scores as the default
// (copy-on-heap) scorer for two small deterministic vectors, across all similarity
// functions and both ordinal orderings, for the given mmap chunk size.
void testSimpleScorer(long maxChunkSize) throws IOException {
  try (Directory dir = new MMapDirectory(createTempDir("testSimpleScorer"), maxChunkSize)) {
    for (int dims : List.of(31, 32, 33)) {
      // dimensions that, in some scenarios, cross the mmap chunk sizes
      byte[][] vectors = new byte[2][dims];
      String fileName = "bar-" + dims;
      try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
        // deterministic vector contents: ascending and descending byte ramps
        for (int i = 0; i < dims; i++) {
          vectors[0][i] = (byte) i;
          vectors[1][i] = (byte) (dims - i);
        }
        byte[] bytes = concat(vectors[0], vectors[1]);
        out.writeBytes(bytes, 0, bytes.length);
      }
      try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
        for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
          var vectorValues = vectorValues(dims, 2, in, sim);
          for (var ords : List.of(List.of(0, 1), List.of(1, 0))) {
            int idx0 = ords.get(0);
            int idx1 = ords.get(1);
            // getRandomVectorScorerSupplier
            var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
            float expected = scorer1.scorer(idx0).score(idx1);
            var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
            assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
            // getRandomVectorScorer
            var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
            assertEquals(scorer3.score(idx1), expected, DELTA);
            var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
            assertEquals(scorer4.score(idx1), expected, DELTA);
          }
        }
      }
    }
  }
}
public void testRandomScorer() throws IOException {
  // Random vector contents with the default mmap chunk size.
  final long chunkSize = MMapDirectory.DEFAULT_MAX_CHUNK_SIZE;
  testRandomScorer(chunkSize, BYTE_ARRAY_RANDOM_FUNC);
}
public void testRandomScorerMax() throws IOException {
  // All-Byte.MAX_VALUE vectors: probes for overflow in the accumulation paths.
  final long chunkSize = MMapDirectory.DEFAULT_MAX_CHUNK_SIZE;
  testRandomScorer(chunkSize, BYTE_ARRAY_MAX_FUNC);
}
public void testRandomScorerMin() throws IOException {
  // All-Byte.MIN_VALUE vectors: probes for overflow in the accumulation paths.
  final long chunkSize = MMapDirectory.DEFAULT_MAX_CHUNK_SIZE;
  testRandomScorer(chunkSize, BYTE_ARRAY_MIN_FUNC);
}
public void testRandomSmallChunkSize() throws IOException {
  // Small random chunk size: many vectors will straddle chunk boundaries.
  testRandomScorer(randomLongBetween(32, 128), BYTE_ARRAY_RANDOM_FUNC);
}
// Compares the memory-segment scorer against the default scorer over randomly sized
// and randomly populated vectors, for every similarity function, repeated TIMES times.
// byteArraySupplier produces a vector's contents given its dimension count.
void testRandomScorer(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier)
    throws IOException {
  try (Directory dir = new MMapDirectory(createTempDir("testRandomScorer"), maxChunkSize)) {
    final int dims = randomIntBetween(1, 4096);
    final int size = randomIntBetween(2, 100);
    final byte[][] vectors = new byte[size][];
    String fileName = "foo-" + dims;
    try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
      // write the vectors contiguously; keep an on-heap copy for query-vector scoring
      for (int i = 0; i < size; i++) {
        var vec = byteArraySupplier.apply(dims);
        out.writeBytes(vec, 0, vec.length);
        vectors[i] = vec;
      }
    }
    try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
      for (int times = 0; times < TIMES; times++) {
        for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
          var vectorValues = vectorValues(dims, size, in, sim);
          int idx0 = randomIntBetween(0, size - 1);
          int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok.
          // getRandomVectorScorerSupplier
          var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
          float expected = scorer1.scorer(idx0).score(idx1);
          var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
          assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
          // getRandomVectorScorer
          var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
          assertEquals(scorer3.score(idx1), expected, DELTA);
          var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
          assertEquals(scorer4.score(idx1), expected, DELTA);
        }
      }
    }
  }
}
public void testRandomSliceSmall() throws IOException {
  // Fixed small configuration: 30 dims, 64-byte chunks, slice starting at offset 1.
  final int dims = 30;
  final long maxChunkSize = 64;
  final int initialOffset = 1;
  testRandomSliceImpl(dims, maxChunkSize, initialOffset, BYTE_ARRAY_RANDOM_FUNC);
}
public void testRandomSlice() throws IOException {
  // Randomized dimensions, chunk size, and non-zero slice offset.
  // Argument evaluation order preserves the original random-draw sequence.
  testRandomSliceImpl(
      randomIntBetween(1, 4096),
      randomLongBetween(32, 128),
      randomIntBetween(1, 129),
      BYTE_ARRAY_RANDOM_FUNC);
}
// Tests with a slice that has a non-zero initial offset.
// Compares the memory-segment scorer against the default scorer when the vector data
// does not start at position 0 of the underlying input, ensuring slice offsets are
// handled correctly by the off-heap access path.
void testRandomSliceImpl(
    int dims, long maxChunkSize, int initialOffset, Function<Integer, byte[]> byteArraySupplier)
    throws IOException {
  try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) {
    final int size = randomIntBetween(2, 100);
    final byte[][] vectors = new byte[size][];
    String fileName = "baz-" + dims;
    try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
      // pad with initialOffset leading bytes so the vector data starts at a non-zero position
      byte[] ba = new byte[initialOffset];
      out.writeBytes(ba, 0, ba.length);
      for (int i = 0; i < size; i++) {
        var vec = byteArraySupplier.apply(dims);
        out.writeBytes(vec, 0, vec.length);
        vectors[i] = vec;
      }
    }
    // note: fixed misspelled local "outter" -> "outer"
    try (var outer = dir.openInput(fileName, IOContext.DEFAULT);
        var in = outer.slice("slice", initialOffset, outer.length() - initialOffset)) {
      for (int times = 0; times < TIMES; times++) {
        for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
          var vectorValues = vectorValues(dims, size, in, sim);
          int idx0 = randomIntBetween(0, size - 1);
          int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok.
          // getRandomVectorScorerSupplier
          var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
          float expected = scorer1.scorer(idx0).score(idx1);
          var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
          assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
          // getRandomVectorScorer
          var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
          assertEquals(scorer3.score(idx1), expected, DELTA);
          var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
          assertEquals(scorer4.score(idx1), expected, DELTA);
        }
      }
    }
  }
}
// Tests that copies in threads do not interfere with each other.
// Each thread gets its own scorer copy (scorer.copy()) and repeatedly scores a fixed
// ordinal; per-copy on-heap scratch buffers must not be shared across threads.
public void testCopiesAcrossThreads() throws Exception {
  final long maxChunkSize = 32;
  final int dims = 34; // dimensions that are larger than the chunk size, to force fallback
  byte[] vec1 = new byte[dims];
  byte[] vec2 = new byte[dims];
  IntStream.range(0, dims).forEach(i -> vec1[i] = 1);
  IntStream.range(0, dims).forEach(i -> vec2[i] = 2);
  try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) {
    String fileName = "biz-" + dims;
    try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
      byte[] bytes = concat(vec1, vec1, vec2, vec2);
      out.writeBytes(bytes, 0, bytes.length);
    }
    try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
      for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
        var vectorValues = vectorValues(dims, 4, in, sim);
        // expected scores come from the single-threaded default scorer
        var scoreSupplier = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
        var expectedScore1 = scoreSupplier.scorer(0).score(1);
        var expectedScore2 = scoreSupplier.scorer(2).score(3);
        var scorer = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
        var tasks =
            List.<Callable<Optional<Throwable>>>of(
                new AssertingScoreCallable(scorer.copy().scorer(0), 1, expectedScore1),
                new AssertingScoreCallable(scorer.copy().scorer(2), 3, expectedScore2));
        var executor = Executors.newFixedThreadPool(2, new NamedThreadFactory("copiesThreads"));
        var results = executor.invokeAll(tasks);
        executor.shutdown();
        assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS));
        // fixed: JUnit assertEquals takes (expected, actual) - the original had them reversed
        assertEquals(0L, results.stream().filter(Predicate.not(Future::isDone)).count());
        for (var res : results) {
          // fixed: missing space after "exception" in the failure message
          assertTrue("Unexpected exception: " + res.get(), res.get().isEmpty());
        }
      }
    }
  }
}
// A callable that scores the given ord and scorer and asserts the expected result.
// Returns an empty Optional on success, or the captured failure otherwise.
static class AssertingScoreCallable implements Callable<Optional<Throwable>> {
  final RandomVectorScorer scorer;
  final int ord;
  final float expectedScore;

  AssertingScoreCallable(RandomVectorScorer scorer, int ord, float expectedScore) {
    this.scorer = scorer;
    this.ord = ord;
    this.expectedScore = expectedScore;
  }

  @Override
  public Optional<Throwable> call() throws Exception {
    try {
      // score repeatedly so that concurrent interference, if any, has a chance to show up
      int remaining = 100;
      while (remaining-- > 0) {
        assertEquals(scorer.score(ord), expectedScore, DELTA);
      }
      return Optional.empty();
    } catch (Throwable t) {
      return Optional.of(t);
    }
  }
}
// Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow
@Nightly
public void testLarge() throws IOException {
  try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) {
    // 8192 dims * 262500 vectors = ~2.1GB of byte vector data
    final int dims = 8192;
    final int size = 262500;
    final String fileName = "large-" + dims;
    try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
      for (int i = 0; i < size; i++) {
        // vector contents are reproducible from (ord, dims), so no on-heap copy is kept
        var vec = vector(i, dims);
        out.writeBytes(vec, 0, vec.length);
      }
    }
    try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
      assert in.length() > Integer.MAX_VALUE;
      for (int times = 0; times < TIMES; times++) {
        for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
          var vectorValues = vectorValues(dims, size, in, sim);
          // always include the last ordinal, whose data lies beyond the 2GB boundary
          int ord1 = randomIntBetween(0, size - 1);
          int ord2 = size - 1;
          for (var ords : List.of(List.of(ord1, ord2), List.of(ord2, ord1))) {
            int idx0 = ords.getFirst();
            int idx1 = ords.getLast();
            // getRandomVectorScorerSupplier
            var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
            float expected = scorer1.scorer(idx0).score(idx1);
            var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
            assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
            // getRandomVectorScorer
            var query = vector(idx0, dims);
            var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, query);
            assertEquals(scorer3.score(idx1), expected, DELTA);
            var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, query);
            assertEquals(scorer4.score(idx1), expected, DELTA);
          }
        }
      }
    }
  }
}
/**
 * Wraps the given input as dense off-heap byte vector values backed by the memory-segment
 * scorer, covering the entire input.
 */
RandomAccessVectorValues vectorValues(
    int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
  var fullSlice = in.slice("byteValues", 0, in.length());
  return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
      dims, size, fullSlice, dims, MEMSEG_SCORER, sim);
}
// creates the vector based on the given ordinal, which is reproducible given the ord and dims
static byte[] vector(int ord, int dims) {
  // Deterministic seed derived from (ord, dims), so the same ordinal always yields the same data.
  var rnd = new Random(Objects.hash(ord, dims));
  var data = new byte[dims];
  int i = 0;
  while (i < dims) {
    data[i] = (byte) RandomNumbers.randomIntBetween(rnd, Byte.MIN_VALUE, Byte.MAX_VALUE);
    i++;
  }
  return data;
}
/** Concatenates byte arrays. */
static byte[] concat(byte[]... arrays) throws IOException {
  // Size the result up front and copy each source array into place.
  int total = 0;
  for (byte[] a : arrays) {
    total += a.length;
  }
  byte[] joined = new byte[total];
  int offset = 0;
  for (byte[] a : arrays) {
    System.arraycopy(a, 0, joined, offset, a.length);
    offset += a.length;
  }
  return joined;
}
/** Returns a random int in {@code [minInclusive, maxInclusive]}, drawn from the test's random source. */
static int randomIntBetween(int minInclusive, int maxInclusive) {
return RandomNumbers.randomIntBetween(random(), minInclusive, maxInclusive);
}
/** Returns a random long in {@code [minInclusive, maxInclusive]}, drawn from the test's random source. */
static long randomLongBetween(long minInclusive, long maxInclusive) {
return RandomNumbers.randomLongBetween(random(), minInclusive, maxInclusive);
}
/** Produces a byte[] of the given size filled with random byte values. */
static final Function<Integer, byte[]> BYTE_ARRAY_RANDOM_FUNC =
    size -> {
      byte[] ba = new byte[size];
      for (int i = 0; i < size; i++) {
        ba[i] = (byte) random().nextInt();
      }
      return ba;
    };
/** Produces a byte[] of the given size with every element set to {@link Byte#MAX_VALUE}. */
static final Function<Integer, byte[]> BYTE_ARRAY_MAX_FUNC =
    size -> {
      byte[] ba = new byte[size];
      Arrays.fill(ba, Byte.MAX_VALUE);
      return ba;
    };
/** Produces a byte[] of the given size with every element set to {@link Byte#MIN_VALUE}. */
static final Function<Integer, byte[]> BYTE_ARRAY_MIN_FUNC =
    size -> {
      byte[] ba = new byte[size];
      Arrays.fill(ba, Byte.MIN_VALUE);
      return ba;
    };
static final int TIMES = 100; // number of times to repeat each randomized scoring round
}

View File

@ -50,6 +50,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.codecs.asserting.AssertingCodec;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
@ -77,6 +78,13 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
abstract Field getKnnVectorField(String name, float[] vector);
/**
* Creates a new directory. Subclasses can override to test different directory implementations.
*/
protected BaseDirectoryWrapper newDirectoryForTest() {
return LuceneTestCase.newDirectory(random());
}
public void testEquals() {
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
Query filter1 = new TermQuery(new Term("id", "id1"));
@ -308,7 +316,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
public void testScoreCosine() throws IOException {
try (Directory d = newDirectory()) {
try (Directory d = newDirectoryForTest()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 1; j <= 5; j++) {
Document doc = new Document();
@ -383,7 +391,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
public void testExplain() throws IOException {
try (Directory d = newDirectory()) {
try (Directory d = newDirectoryForTest()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
@ -410,7 +418,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
public void testExplainMultipleSegments() throws IOException {
try (Directory d = newDirectory()) {
try (Directory d = newDirectoryForTest()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 0; j < 5; j++) {
Document doc = new Document();
@ -443,7 +451,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
* number of top K documents, but no more than K documents in total (otherwise we might occasionally
* randomly fail to find one).
*/
try (Directory d = newDirectory()) {
try (Directory d = newDirectoryForTest()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
int r = 0;
for (int i = 0; i < 5; i++) {
@ -479,7 +487,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
int dimension = atLeast(5);
int numIters = atLeast(10);
boolean everyDocHasAVector = random().nextBoolean();
try (Directory d = newDirectory()) {
try (Directory d = newDirectoryForTest()) {
RandomIndexWriter w = new RandomIndexWriter(random(), d);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
@ -518,7 +526,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
int numDocs = 1000;
int dimension = atLeast(5);
int numIters = atLeast(10);
try (Directory d = newDirectory()) {
try (Directory d = newDirectoryForTest()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
// format
@ -604,7 +612,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
public void testFilterWithSameScore() throws IOException {
int numDocs = 100;
int dimension = atLeast(5);
try (Directory d = newDirectory()) {
try (Directory d = newDirectoryForTest()) {
// Always use the default kNN format to have predictable behavior around when it hits
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
// format
@ -644,7 +652,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
public void testDeletes() throws IOException {
try (Directory dir = newDirectory();
try (Directory dir = newDirectoryForTest();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
final int numDocs = atLeast(100);
final int dim = 30;
@ -688,7 +696,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
public void testAllDeletes() throws IOException {
try (Directory dir = newDirectory();
try (Directory dir = newDirectoryForTest();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
final int numDocs = atLeast(100);
final int dim = 30;
@ -717,7 +725,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
*/
public void testNoLiveDocsReader() throws IOException {
IndexWriterConfig iwc = newIndexWriterConfig();
try (Directory dir = newDirectory();
try (Directory dir = newDirectoryForTest();
IndexWriter w = new IndexWriter(dir, iwc)) {
final int numDocs = 10;
final int dim = 30;
@ -745,7 +753,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
*/
public void testBitSetQuery() throws IOException {
IndexWriterConfig iwc = newIndexWriterConfig();
try (Directory dir = newDirectory();
try (Directory dir = newDirectoryForTest();
IndexWriter w = new IndexWriter(dir, iwc)) {
final int numDocs = 100;
final int dim = 30;
@ -853,7 +861,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
Directory getIndexStore(
String field, VectorSimilarityFunction vectorSimilarityFunction, float[]... contents)
throws IOException {
Directory indexStore = newDirectory();
Directory indexStore = newDirectoryForTest();
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
@ -886,7 +894,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
* preserving the order of the added documents.
*/
private Directory getStableIndexStore(String field, float[]... contents) throws IOException {
Directory indexStore = newDirectory();
Directory indexStore = newDirectoryForTest();
try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) {
for (int i = 0; i < contents.length; ++i) {
Document doc = new Document();
@ -1031,7 +1039,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
public void testSameFieldDifferentFormats() throws IOException {
try (Directory directory = newDirectory()) {
try (Directory directory = newDirectoryForTest()) {
MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
KnnVectorsFormat format1 = randomVectorFormat(VectorEncoding.FLOAT32);

View File

@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.io.UncheckedIOException;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
import org.apache.lucene.tests.store.MockDirectoryWrapper;
/**
 * Runs {@code TestKnnByteVectorQuery} against an {@code MMapDirectory}-backed directory, so the
 * byte-vector queries are exercised over memory-mapped input rather than the default test
 * directory implementation.
 */
public class TestKnnByteVectorQueryMMap extends TestKnnByteVectorQuery {
@Override
protected BaseDirectoryWrapper newDirectoryForTest() {
try {
return new MockDirectoryWrapper(
random(), new MMapDirectory(createTempDir("TestKnnByteVectorQueryMMap")));
} catch (IOException e) {
// The overridden method declares no checked exceptions, so surface directory-creation
// failures as an unchecked wrapper.
throw new UncheckedIOException(e);
}
}
}

View File

@ -20,6 +20,7 @@ import java.io.Closeable;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.internal.tests.TestSecrets;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
@ -27,6 +28,11 @@ import org.apache.lucene.store.IndexInput;
* Used by MockDirectoryWrapper to create an input stream that keeps track of when it's been closed.
*/
public class MockIndexInputWrapper extends FilterIndexInput {
static {
TestSecrets.getFilterInputIndexAccess().addTestFilterType(MockIndexInputWrapper.class);
}
private MockDirectoryWrapper dir;
final String name;
private volatile boolean closed;

View File

@ -17,6 +17,7 @@
package org.apache.lucene.tests.store;
import java.io.IOException;
import org.apache.lucene.internal.tests.TestSecrets;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.SuppressForbidden;
import org.apache.lucene.util.ThreadInterruptedException;
@ -28,6 +29,11 @@ import org.apache.lucene.util.ThreadInterruptedException;
*/
class SlowClosingMockIndexInputWrapper extends MockIndexInputWrapper {
static {
TestSecrets.getFilterInputIndexAccess()
.addTestFilterType(SlowClosingMockIndexInputWrapper.class);
}
public SlowClosingMockIndexInputWrapper(
MockDirectoryWrapper dir, String name, IndexInput delegate) {
super(dir, name, delegate, null);

View File

@ -17,6 +17,7 @@
package org.apache.lucene.tests.store;
import java.io.IOException;
import org.apache.lucene.internal.tests.TestSecrets;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.SuppressForbidden;
import org.apache.lucene.util.ThreadInterruptedException;
@ -27,6 +28,11 @@ import org.apache.lucene.util.ThreadInterruptedException;
*/
class SlowOpeningMockIndexInputWrapper extends MockIndexInputWrapper {
static {
TestSecrets.getFilterInputIndexAccess()
.addTestFilterType(SlowOpeningMockIndexInputWrapper.class);
}
@SuppressForbidden(reason = "Thread sleep")
public SlowOpeningMockIndexInputWrapper(
MockDirectoryWrapper dir, String name, IndexInput delegate) throws IOException {