mirror of https://github.com/apache/lucene.git
Add a MemorySegment Vector scorer - for scoring without copying on-heap (#13339)
Add a MemorySegment Vector scorer - for scoring without copying on-heap. The vector scorer loads values directly from the backing memory segment when available. Otherwise, if the vector data spans across segments the scorer copies the vector data on-heap. A benchmark shows ~2x performance improvement of this scorer over the default copy-on-heap scorer. The scorer currently only operates on vectors with an element size of byte. We can evaluate if and how to support floats separately.
This commit is contained in:
parent
f70999980c
commit
05f04aa08a
|
@ -348,6 +348,8 @@ Optimizations
|
|||
|
||||
* GITHUB#13392: Replace Map<Long, Object> by primitive LongObjectHashMap. (Bruno Roustant)
|
||||
|
||||
* GITHUB#13339: Add a MemorySegment Vector scorer - for scoring without copying on-heap (Chris Hegarty)
|
||||
|
||||
Bug Fixes
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.benchmark.jmh;
|
||||
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.openjdk.jmh.annotations.*;
|
||||
|
||||
@BenchmarkMode(Mode.Throughput)
|
||||
@OutputTimeUnit(TimeUnit.MICROSECONDS)
|
||||
@State(Scope.Benchmark)
|
||||
// first iteration is complete garbage, so make sure we really warmup
|
||||
@Warmup(iterations = 4, time = 1)
|
||||
// real iterations. not useful to spend tons of time here, better to fork more
|
||||
@Measurement(iterations = 5, time = 1)
|
||||
// engage some noise reduction
|
||||
@Fork(
|
||||
value = 3,
|
||||
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
|
||||
public class VectorScorerBenchmark {
|
||||
|
||||
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
|
||||
int size;
|
||||
|
||||
Directory dir;
|
||||
IndexInput in;
|
||||
RandomAccessVectorValues vectorValues;
|
||||
byte[] vec1, vec2;
|
||||
RandomVectorScorer scorer;
|
||||
|
||||
@Setup(Level.Iteration)
|
||||
public void init() throws IOException {
|
||||
vec1 = new byte[size];
|
||||
vec2 = new byte[size];
|
||||
ThreadLocalRandom.current().nextBytes(vec1);
|
||||
ThreadLocalRandom.current().nextBytes(vec2);
|
||||
|
||||
dir = new MMapDirectory(Files.createTempDirectory("VectorScorerBenchmark"));
|
||||
try (IndexOutput out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
|
||||
out.writeBytes(vec1, 0, vec1.length);
|
||||
out.writeBytes(vec2, 0, vec2.length);
|
||||
}
|
||||
in = dir.openInput("vector.data", IOContext.DEFAULT);
|
||||
vectorValues = vectorValues(size, 2, in, DOT_PRODUCT);
|
||||
scorer =
|
||||
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
|
||||
.getRandomVectorScorerSupplier(DOT_PRODUCT, vectorValues)
|
||||
.scorer(0);
|
||||
}
|
||||
|
||||
@TearDown
|
||||
public void teardown() throws IOException {
|
||||
IOUtils.close(dir, in);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public float binaryDotProductDefault() throws IOException {
|
||||
return scorer.score(1);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
|
||||
public float binaryDotProductMemSeg() throws IOException {
|
||||
return scorer.score(1);
|
||||
}
|
||||
|
||||
static RandomAccessVectorValues vectorValues(
|
||||
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
|
||||
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim);
|
||||
}
|
||||
|
||||
static final class ThrowingFlatVectorScorer implements FlatVectorsScorer {
|
||||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -29,6 +29,9 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
|||
* @lucene.experimental
|
||||
*/
|
||||
public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
||||
|
||||
public static final DefaultFlatVectorScorer INSTANCE = new DefaultFlatVectorScorer();
|
||||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.hnsw;
|
||||
|
||||
import org.apache.lucene.internal.vectorization.VectorizationProvider;
|
||||
|
||||
/**
|
||||
* Utilities for {@link FlatVectorsScorer}.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class FlatVectorScorerUtil {
|
||||
|
||||
private static final VectorizationProvider IMPL = VectorizationProvider.getInstance();
|
||||
|
||||
private FlatVectorScorerUtil() {}
|
||||
|
||||
/**
|
||||
* Returns a FlatVectorsScorer that supports the Lucene99 format. Scorers retrieved through this
|
||||
* method may be optimized on certain platforms. Otherwise, a DefaultFlatVectorScorer is returned.
|
||||
*/
|
||||
public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
|
||||
return IMPL.getLucene99FlatVectorsScorer();
|
||||
}
|
||||
}
|
|
@ -22,7 +22,7 @@ import java.util.concurrent.ExecutorService;
|
|||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.MergePolicy;
|
||||
|
@ -139,7 +139,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
|||
|
||||
/** The format for storing, reading, merging vectors on disk */
|
||||
private static final FlatVectorsFormat flatVectorsFormat =
|
||||
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
|
||||
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
|
||||
|
||||
private final int numMergeWorkers;
|
||||
private final TaskExecutor mergeExec;
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.codecs.lucene99;
|
|||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||
|
@ -48,7 +49,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
|
|||
static final String VECTOR_DATA_EXTENSION = "veq";
|
||||
|
||||
private static final FlatVectorsFormat rawVectorFormat =
|
||||
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
|
||||
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
|
||||
|
||||
/** The minimum confidence interval */
|
||||
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
|
||||
|
@ -101,7 +102,8 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
|
|||
this.bits = (byte) bits;
|
||||
this.confidenceInterval = confidenceInterval;
|
||||
this.compress = compress;
|
||||
this.flatVectorScorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
|
||||
this.flatVectorScorer =
|
||||
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
|
||||
}
|
||||
|
||||
public static float calculateDefaultConfidenceInterval(int vectorDimension) {
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.internal.tests;
|
||||
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
|
||||
/**
|
||||
* Access to {@link org.apache.lucene.store.FilterIndexInput} internals exposed to the test
|
||||
* framework.
|
||||
*
|
||||
* @lucene.internal
|
||||
*/
|
||||
public interface FilterIndexInputAccess {
|
||||
/** Adds the given test FilterIndexInput class. */
|
||||
void addTestFilterType(Class<? extends FilterIndexInput> cls);
|
||||
}
|
|
@ -23,6 +23,7 @@ import org.apache.lucene.index.ConcurrentMergeScheduler;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.SegmentReader;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
|
||||
/**
|
||||
* A set of static methods returning accessors for internal, package-private functionality in
|
||||
|
@ -48,12 +49,14 @@ public final class TestSecrets {
|
|||
ensureInitialized.accept(ConcurrentMergeScheduler.class);
|
||||
ensureInitialized.accept(SegmentReader.class);
|
||||
ensureInitialized.accept(IndexWriter.class);
|
||||
ensureInitialized.accept(FilterIndexInput.class);
|
||||
}
|
||||
|
||||
private static IndexPackageAccess indexPackageAccess;
|
||||
private static ConcurrentMergeSchedulerAccess cmsAccess;
|
||||
private static SegmentReaderAccess segmentReaderAccess;
|
||||
private static IndexWriterAccess indexWriterAccess;
|
||||
private static FilterIndexInputAccess filterIndexInputAccess;
|
||||
|
||||
private TestSecrets() {}
|
||||
|
||||
|
@ -81,6 +84,12 @@ public final class TestSecrets {
|
|||
return Objects.requireNonNull(indexWriterAccess);
|
||||
}
|
||||
|
||||
/** Return the accessor to internal secrets for a {@link FilterIndexInput}. */
|
||||
public static FilterIndexInputAccess getFilterInputIndexAccess() {
|
||||
ensureCaller();
|
||||
return Objects.requireNonNull(filterIndexInputAccess);
|
||||
}
|
||||
|
||||
/** For internal initialization only. */
|
||||
public static void setIndexWriterAccess(IndexWriterAccess indexWriterAccess) {
|
||||
ensureNull(TestSecrets.indexWriterAccess);
|
||||
|
@ -105,6 +114,12 @@ public final class TestSecrets {
|
|||
TestSecrets.segmentReaderAccess = segmentReaderAccess;
|
||||
}
|
||||
|
||||
/** For internal initialization only. */
|
||||
public static void setFilterInputIndexAccess(FilterIndexInputAccess filterIndexInputAccess) {
|
||||
ensureNull(TestSecrets.filterIndexInputAccess);
|
||||
TestSecrets.filterIndexInputAccess = filterIndexInputAccess;
|
||||
}
|
||||
|
||||
private static void ensureNull(Object ob) {
|
||||
if (ob != null) {
|
||||
throw new AssertionError(
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
|
||||
package org.apache.lucene.internal.vectorization;
|
||||
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
|
||||
/** Default provider returning scalar implementations. */
|
||||
final class DefaultVectorizationProvider extends VectorizationProvider {
|
||||
|
||||
|
@ -30,4 +33,9 @@ final class DefaultVectorizationProvider extends VectorizationProvider {
|
|||
public VectorUtilSupport getVectorUtilSupport() {
|
||||
return vectorUtilSupport;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlatVectorsScorer getLucene99FlatVectorsScorer() {
|
||||
return DefaultFlatVectorScorer.INSTANCE;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.Set;
|
|||
import java.util.function.Predicate;
|
||||
import java.util.logging.Logger;
|
||||
import java.util.stream.Stream;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.util.Constants;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
|
@ -91,6 +92,9 @@ public abstract class VectorizationProvider {
|
|||
*/
|
||||
public abstract VectorUtilSupport getVectorUtilSupport();
|
||||
|
||||
/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
|
||||
public abstract FlatVectorsScorer getLucene99FlatVectorsScorer();
|
||||
|
||||
// *** Lookup mechanism: ***
|
||||
|
||||
private static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName());
|
||||
|
@ -177,7 +181,10 @@ public abstract class VectorizationProvider {
|
|||
}
|
||||
|
||||
// add all possible callers here as FQCN:
|
||||
private static final Set<String> VALID_CALLERS = Set.of("org.apache.lucene.util.VectorUtil");
|
||||
private static final Set<String> VALID_CALLERS =
|
||||
Set.of(
|
||||
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
|
||||
"org.apache.lucene.util.VectorUtil");
|
||||
|
||||
private static void ensureCaller() {
|
||||
final boolean validCaller =
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package org.apache.lucene.store;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import org.apache.lucene.internal.tests.TestSecrets;
|
||||
|
||||
/**
|
||||
* IndexInput implementation that delegates calls to another directory. This class can be used to
|
||||
|
@ -29,6 +31,12 @@ import java.io.IOException;
|
|||
*/
|
||||
public class FilterIndexInput extends IndexInput {
|
||||
|
||||
static final CopyOnWriteArrayList<Class<?>> TEST_FILTER_INPUTS = new CopyOnWriteArrayList<>();
|
||||
|
||||
static {
|
||||
TestSecrets.setFilterInputIndexAccess(TEST_FILTER_INPUTS::add);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unwraps all FilterIndexInputs until the first non-FilterIndexInput IndexInput instance and
|
||||
* returns it
|
||||
|
@ -40,6 +48,17 @@ public class FilterIndexInput extends IndexInput {
|
|||
return in;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unwraps all test FilterIndexInputs until the first non-test FilterIndexInput IndexInput
|
||||
* instance and returns it
|
||||
*/
|
||||
public static IndexInput unwrapOnlyTest(IndexInput in) {
|
||||
while (in instanceof FilterIndexInput && TEST_FILTER_INPUTS.contains(in.getClass())) {
|
||||
in = ((FilterIndexInput) in).in;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
protected final IndexInput in;
|
||||
|
||||
/** Creates a FilterIndexInput with a resource description and wrapped delegate IndexInput */
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.internal.vectorization;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
import java.util.Optional;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
|
||||
abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
||||
extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
|
||||
final int vectorByteSize;
|
||||
final MemorySegmentAccessInput input;
|
||||
final MemorySegment query;
|
||||
byte[] scratch;
|
||||
|
||||
/**
|
||||
* Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is
|
||||
* returned.
|
||||
*/
|
||||
public static Optional<Lucene99MemorySegmentByteVectorScorer> create(
|
||||
VectorSimilarityFunction type,
|
||||
IndexInput input,
|
||||
RandomAccessVectorValues values,
|
||||
byte[] queryVector) {
|
||||
input = FilterIndexInput.unwrapOnlyTest(input);
|
||||
if (!(input instanceof MemorySegmentAccessInput msInput)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
checkInvariants(values.size(), values.getVectorByteLength(), input);
|
||||
return switch (type) {
|
||||
case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector));
|
||||
case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector));
|
||||
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector));
|
||||
case MAXIMUM_INNER_PRODUCT -> Optional.of(
|
||||
new MaxInnerProductScorer(msInput, values, queryVector));
|
||||
};
|
||||
}
|
||||
|
||||
Lucene99MemorySegmentByteVectorScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] queryVector) {
|
||||
super(values);
|
||||
this.input = input;
|
||||
this.vectorByteSize = values.getVectorByteLength();
|
||||
this.query = MemorySegment.ofArray(queryVector);
|
||||
}
|
||||
|
||||
final MemorySegment getSegment(int ord) throws IOException {
|
||||
checkOrdinal(ord);
|
||||
long byteOffset = (long) ord * vectorByteSize;
|
||||
MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
|
||||
if (seg == null) {
|
||||
if (scratch == null) {
|
||||
scratch = new byte[vectorByteSize];
|
||||
}
|
||||
input.readBytes(byteOffset, scratch, 0, vectorByteSize);
|
||||
seg = MemorySegment.ofArray(scratch);
|
||||
}
|
||||
return seg;
|
||||
}
|
||||
|
||||
static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
|
||||
if (input.length() < (long) vectorByteLength * maxOrd) {
|
||||
throw new IllegalArgumentException("input length is less than expected vector data");
|
||||
}
|
||||
}
|
||||
|
||||
final void checkOrdinal(int ord) {
|
||||
if (ord < 0 || ord >= maxOrd()) {
|
||||
throw new IllegalArgumentException("illegal ordinal: " + ord);
|
||||
}
|
||||
}
|
||||
|
||||
static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
CosineScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw = PanamaVectorUtilSupport.cosine(query, getSegment(node));
|
||||
return (1 + raw) / 2;
|
||||
}
|
||||
}
|
||||
|
||||
static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
DotProductScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
|
||||
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
||||
return 0.5f + raw / (float) (query.byteSize() * (1 << 15));
|
||||
}
|
||||
}
|
||||
|
||||
static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw = PanamaVectorUtilSupport.squareDistance(query, getSegment(node));
|
||||
return 1 / (1f + raw);
|
||||
}
|
||||
}
|
||||
|
||||
static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
MaxInnerProductScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw = PanamaVectorUtilSupport.dotProduct(query, getSegment(node));
|
||||
if (raw < 0) {
|
||||
return 1 / (1 + -1 * raw);
|
||||
}
|
||||
return raw + 1;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.internal.vectorization;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
import java.util.Optional;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
|
||||
/** A score supplier of vectors whose element size is byte. */
|
||||
public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
||||
implements RandomVectorScorerSupplier {
|
||||
final int vectorByteSize;
|
||||
final int maxOrd;
|
||||
final MemorySegmentAccessInput input;
|
||||
final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds
|
||||
byte[] scratch1, scratch2;
|
||||
|
||||
/**
|
||||
* Return an optional whose value, if present, is the scorer supplier. Otherwise, an empty
|
||||
* optional is returned.
|
||||
*/
|
||||
static Optional<RandomVectorScorerSupplier> create(
|
||||
VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) {
|
||||
input = FilterIndexInput.unwrapOnlyTest(input);
|
||||
if (!(input instanceof MemorySegmentAccessInput msInput)) {
|
||||
return Optional.empty();
|
||||
}
|
||||
checkInvariants(values.size(), values.getVectorByteLength(), input);
|
||||
return switch (type) {
|
||||
case COSINE -> Optional.of(new CosineSupplier(msInput, values));
|
||||
case DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values));
|
||||
case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values));
|
||||
case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values));
|
||||
};
|
||||
}
|
||||
|
||||
Lucene99MemorySegmentByteVectorScorerSupplier(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
this.input = input;
|
||||
this.values = values;
|
||||
this.vectorByteSize = values.getVectorByteLength();
|
||||
this.maxOrd = values.size();
|
||||
}
|
||||
|
||||
static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
|
||||
if (input.length() < (long) vectorByteLength * maxOrd) {
|
||||
throw new IllegalArgumentException("input length is less than expected vector data");
|
||||
}
|
||||
}
|
||||
|
||||
final void checkOrdinal(int ord) {
|
||||
if (ord < 0 || ord >= maxOrd) {
|
||||
throw new IllegalArgumentException("illegal ordinal: " + ord);
|
||||
}
|
||||
}
|
||||
|
||||
final MemorySegment getFirstSegment(int ord) throws IOException {
|
||||
long byteOffset = (long) ord * vectorByteSize;
|
||||
MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
|
||||
if (seg == null) {
|
||||
if (scratch1 == null) {
|
||||
scratch1 = new byte[vectorByteSize];
|
||||
}
|
||||
input.readBytes(byteOffset, scratch1, 0, vectorByteSize);
|
||||
seg = MemorySegment.ofArray(scratch1);
|
||||
}
|
||||
return seg;
|
||||
}
|
||||
|
||||
final MemorySegment getSecondSegment(int ord) throws IOException {
|
||||
long byteOffset = (long) ord * vectorByteSize;
|
||||
MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
|
||||
if (seg == null) {
|
||||
if (scratch2 == null) {
|
||||
scratch2 = new byte[vectorByteSize];
|
||||
}
|
||||
input.readBytes(byteOffset, scratch2, 0, vectorByteSize);
|
||||
seg = MemorySegment.ofArray(scratch2);
|
||||
}
|
||||
return seg;
|
||||
}
|
||||
|
||||
static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) {
|
||||
checkOrdinal(ord);
|
||||
return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw = PanamaVectorUtilSupport.cosine(getFirstSegment(ord), getSecondSegment(node));
|
||||
return (1 + raw) / 2;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public CosineSupplier copy() throws IOException {
|
||||
return new CosineSupplier(input.clone(), values);
|
||||
}
|
||||
}
|
||||
|
||||
static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) {
|
||||
checkOrdinal(ord);
|
||||
return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
|
||||
float raw =
|
||||
PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node));
|
||||
return 0.5f + raw / (float) (values.dimension() * (1 << 15));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public DotProductSupplier copy() throws IOException {
|
||||
return new DotProductSupplier(input.clone(), values);
|
||||
}
|
||||
}
|
||||
|
||||
static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) {
|
||||
checkOrdinal(ord);
|
||||
return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw =
|
||||
PanamaVectorUtilSupport.squareDistance(getFirstSegment(ord), getSecondSegment(node));
|
||||
return 1 / (1f + raw);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public EuclideanSupplier copy() throws IOException {
|
||||
return new EuclideanSupplier(input.clone(), values);
|
||||
}
|
||||
}
|
||||
|
||||
static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) {
|
||||
checkOrdinal(ord);
|
||||
return new RandomVectorScorer.AbstractRandomVectorScorer(values) {
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
float raw =
|
||||
PanamaVectorUtilSupport.dotProduct(getFirstSegment(ord), getSecondSegment(node));
|
||||
if (raw < 0) {
|
||||
return 1 / (1 + -1 * raw);
|
||||
}
|
||||
return raw + 1;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public MaxInnerProductSupplier copy() throws IOException {
|
||||
return new MaxInnerProductSupplier(input.clone(), values);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.internal.vectorization;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
|
||||
public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer {
|
||||
|
||||
public static final Lucene99MemorySegmentFlatVectorsScorer INSTANCE =
|
||||
new Lucene99MemorySegmentFlatVectorsScorer(DefaultFlatVectorScorer.INSTANCE);
|
||||
|
||||
private final FlatVectorsScorer delegate;
|
||||
|
||||
private Lucene99MemorySegmentFlatVectorsScorer(FlatVectorsScorer delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues)
|
||||
throws IOException {
|
||||
// currently only supports binary vectors
|
||||
if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
|
||||
var scorer =
|
||||
Lucene99MemorySegmentByteVectorScorerSupplier.create(
|
||||
similarityType, vectorValues.getSlice(), vectorValues);
|
||||
if (scorer.isPresent()) {
|
||||
return scorer.get();
|
||||
}
|
||||
}
|
||||
return delegate.getRandomVectorScorerSupplier(similarityType, vectorValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityType,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target)
|
||||
throws IOException {
|
||||
// currently only supports binary vectors, so always delegate
|
||||
return delegate.getRandomVectorScorer(similarityType, vectorValues, target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityType,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] queryVector)
|
||||
throws IOException {
|
||||
checkDimensions(queryVector.length, vectorValues.dimension());
|
||||
if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
|
||||
var scorer =
|
||||
Lucene99MemorySegmentByteVectorScorer.create(
|
||||
similarityType, vectorValues.getSlice(), vectorValues, queryVector);
|
||||
if (scorer.isPresent()) {
|
||||
return scorer.get();
|
||||
}
|
||||
}
|
||||
return delegate.getRandomVectorScorer(similarityType, vectorValues, queryVector);
|
||||
}
|
||||
|
||||
static void checkDimensions(int queryLen, int fieldLen) {
|
||||
if (queryLen != fieldLen) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Lucene99MemorySegmentFlatVectorsScorer()";
|
||||
}
|
||||
}
|
|
@ -16,6 +16,8 @@
|
|||
*/
|
||||
package org.apache.lucene.internal.vectorization;
|
||||
|
||||
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
|
||||
import static java.nio.ByteOrder.LITTLE_ENDIAN;
|
||||
import static jdk.incubator.vector.VectorOperators.ADD;
|
||||
import static jdk.incubator.vector.VectorOperators.B2I;
|
||||
import static jdk.incubator.vector.VectorOperators.B2S;
|
||||
|
@ -23,6 +25,7 @@ import static jdk.incubator.vector.VectorOperators.LSHR;
|
|||
import static jdk.incubator.vector.VectorOperators.S2I;
|
||||
import static jdk.incubator.vector.VectorOperators.ZERO_EXTEND_B2S;
|
||||
|
||||
import java.lang.foreign.MemorySegment;
|
||||
import jdk.incubator.vector.ByteVector;
|
||||
import jdk.incubator.vector.FloatVector;
|
||||
import jdk.incubator.vector.IntVector;
|
||||
|
@ -307,39 +310,44 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
@Override
|
||||
public int dotProduct(byte[] a, byte[] b) {
|
||||
return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
|
||||
}
|
||||
|
||||
public static int dotProduct(MemorySegment a, MemorySegment b) {
|
||||
assert a.byteSize() == b.byteSize();
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
|
||||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||
if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||
// compute vectorized dot product consistent with VPDPBUSD instruction
|
||||
if (VECTOR_BITSIZE >= 512) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
i += BYTE_SPECIES.loopBound(a.byteSize());
|
||||
res += dotProductBody512(a, b, i);
|
||||
} else if (VECTOR_BITSIZE == 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
i += BYTE_SPECIES.loopBound(a.byteSize());
|
||||
res += dotProductBody256(a, b, i);
|
||||
} else {
|
||||
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
|
||||
i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
|
||||
i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
|
||||
res += dotProductBody128(a, b, i);
|
||||
}
|
||||
}
|
||||
|
||||
// scalar tail
|
||||
for (; i < a.length; i++) {
|
||||
res += b[i] * a[i];
|
||||
for (; i < a.byteSize(); i++) {
|
||||
res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/** vectorized dot product body (512 bit vectors) */
|
||||
private int dotProductBody512(byte[] a, byte[] b, int limit) {
|
||||
private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
|
||||
IntVector acc = IntVector.zero(INT_SPECIES);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
|
||||
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
|
@ -355,11 +363,11 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
|
||||
/** vectorized dot product body (256 bit vectors) */
|
||||
private int dotProductBody256(byte[] a, byte[] b, int limit) {
|
||||
private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_256);
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// 32-bit multiply and add into accumulator
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
|
||||
|
@ -371,13 +379,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
|
||||
/** vectorized dot product body (128 bit vectors) */
|
||||
private int dotProductBody128(byte[] a, byte[] b, int limit) {
|
||||
private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
|
||||
// 4 bytes at a time (re-loading half the vector each time!)
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
|
||||
// load 8 bytes
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// process first "half" only: 16-bit multiply
|
||||
Vector<Short> va16 = va8.convert(B2S, 0);
|
||||
|
@ -569,6 +577,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
@Override
|
||||
public float cosine(byte[] a, byte[] b) {
|
||||
return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
|
||||
}
|
||||
|
||||
public static float cosine(MemorySegment a, MemorySegment b) {
|
||||
int i = 0;
|
||||
int sum = 0;
|
||||
int norm1 = 0;
|
||||
|
@ -576,17 +588,17 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||
if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||
final float[] ret;
|
||||
if (VECTOR_BITSIZE >= 512) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
i += BYTE_SPECIES.loopBound((int) a.byteSize());
|
||||
ret = cosineBody512(a, b, i);
|
||||
} else if (VECTOR_BITSIZE == 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
i += BYTE_SPECIES.loopBound((int) a.byteSize());
|
||||
ret = cosineBody256(a, b, i);
|
||||
} else {
|
||||
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
|
||||
i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
|
||||
i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
|
||||
ret = cosineBody128(a, b, i);
|
||||
}
|
||||
sum += ret[0];
|
||||
|
@ -595,9 +607,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
|
||||
// scalar tail
|
||||
for (; i < a.length; i++) {
|
||||
byte elem1 = a[i];
|
||||
byte elem2 = b[i];
|
||||
for (; i < a.byteSize(); i++) {
|
||||
byte elem1 = a.get(JAVA_BYTE, i);
|
||||
byte elem2 = b.get(JAVA_BYTE, i);
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
|
@ -606,13 +618,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
|
||||
/** vectorized cosine body (512 bit vectors) */
|
||||
private float[] cosineBody512(byte[] a, byte[] b, int limit) {
|
||||
private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
|
||||
IntVector accSum = IntVector.zero(INT_SPECIES);
|
||||
IntVector accNorm1 = IntVector.zero(INT_SPECIES);
|
||||
IntVector accNorm2 = IntVector.zero(INT_SPECIES);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
|
||||
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
|
@ -636,13 +648,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
|
||||
/** vectorized cosine body (256 bit vectors) */
|
||||
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
|
||||
private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// 16-bit multiply, and add into accumulators
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
|
||||
|
@ -661,13 +673,13 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
|
||||
/** vectorized cosine body (128 bit vectors) */
|
||||
private float[] cosineBody128(byte[] a, byte[] b, int limit) {
|
||||
private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// process first half only: 16-bit multiply
|
||||
Vector<Short> va16 = va8.convert(B2S, 0);
|
||||
|
@ -689,35 +701,40 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
@Override
|
||||
public int squareDistance(byte[] a, byte[] b) {
|
||||
return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
|
||||
}
|
||||
|
||||
public static int squareDistance(MemorySegment a, MemorySegment b) {
|
||||
assert a.byteSize() == b.byteSize();
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
|
||||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||
if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||
if (VECTOR_BITSIZE >= 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
i += BYTE_SPECIES.loopBound((int) a.byteSize());
|
||||
res += squareDistanceBody256(a, b, i);
|
||||
} else {
|
||||
i += ByteVector.SPECIES_64.loopBound(a.length);
|
||||
i += ByteVector.SPECIES_64.loopBound((int) a.byteSize());
|
||||
res += squareDistanceBody128(a, b, i);
|
||||
}
|
||||
}
|
||||
|
||||
// scalar tail
|
||||
for (; i < a.length; i++) {
|
||||
int diff = a[i] - b[i];
|
||||
for (; i < a.byteSize(); i++) {
|
||||
int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
|
||||
res += diff * diff;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/** vectorized square distance body (256+ bit vectors) */
|
||||
private int squareDistanceBody256(byte[] a, byte[] b, int limit) {
|
||||
private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
|
||||
IntVector acc = IntVector.zero(INT_SPECIES);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// 32-bit sub, multiply, and add into accumulators
|
||||
// TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
|
||||
|
@ -731,14 +748,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
}
|
||||
|
||||
/** vectorized square distance body (128 bit vectors) */
|
||||
private int squareDistanceBody128(byte[] a, byte[] b, int limit) {
|
||||
private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
|
||||
// 128-bit implementation, which must "split up" vectors due to widening conversions
|
||||
// it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
|
||||
IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
|
||||
ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
|
||||
|
||||
// 16-bit sub
|
||||
Vector<Short> va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0);
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.security.PrivilegedAction;
|
|||
import java.util.Locale;
|
||||
import java.util.logging.Logger;
|
||||
import jdk.incubator.vector.FloatVector;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.util.Constants;
|
||||
import org.apache.lucene.util.SuppressForbidden;
|
||||
|
||||
|
@ -73,4 +74,9 @@ final class PanamaVectorizationProvider extends VectorizationProvider {
|
|||
public VectorUtilSupport getVectorUtilSupport() {
|
||||
return vectorUtilSupport;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlatVectorsScorer getLucene99FlatVectorsScorer() {
|
||||
return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.store;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
|
||||
/**
|
||||
* Provides access to the backing memory segment.
|
||||
*
|
||||
* <p>Expert API, allows access to the backing memory.
|
||||
*/
|
||||
public interface MemorySegmentAccessInput extends RandomAccessInput, Cloneable {
|
||||
|
||||
/** Returns the memory segment for a given position and length, or null. */
|
||||
MemorySegment segmentSliceOrNull(long pos, int len) throws IOException;
|
||||
|
||||
MemorySegmentAccessInput clone();
|
||||
}
|
|
@ -36,7 +36,8 @@ import org.apache.lucene.util.GroupVIntUtil;
|
|||
* chunkSizePower</code>).
|
||||
*/
|
||||
@SuppressWarnings("preview")
|
||||
abstract class MemorySegmentIndexInput extends IndexInput implements RandomAccessInput {
|
||||
abstract class MemorySegmentIndexInput extends IndexInput
|
||||
implements RandomAccessInput, MemorySegmentAccessInput {
|
||||
static final ValueLayout.OfByte LAYOUT_BYTE = ValueLayout.JAVA_BYTE;
|
||||
static final ValueLayout.OfShort LAYOUT_LE_SHORT =
|
||||
ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
|
||||
|
@ -562,6 +563,10 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
|
|||
}
|
||||
}
|
||||
|
||||
static boolean checkIndex(long index, long length) {
|
||||
return index >= 0 && index < length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void close() throws IOException {
|
||||
if (curSegment == null) {
|
||||
|
@ -673,6 +678,16 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
|
|||
throw alreadyClosed(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException {
|
||||
try {
|
||||
Objects.checkIndex(pos + len, this.length + 1);
|
||||
return curSegment.asSlice(pos, len);
|
||||
} catch (IndexOutOfBoundsException e) {
|
||||
throw handlePositionalIOOBE(e, "segmentSliceOrNull", pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** This class adds offset support to MemorySegmentIndexInput, which is needed for slices. */
|
||||
|
@ -738,6 +753,20 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
|
|||
return super.readLong(pos + offset);
|
||||
}
|
||||
|
||||
public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException {
|
||||
if (pos + len > length) {
|
||||
throw handlePositionalIOOBE(null, "segmentSliceOrNull", pos);
|
||||
}
|
||||
pos = pos + offset;
|
||||
final int si = (int) (pos >> chunkSizePower);
|
||||
final MemorySegment seg = segments[si];
|
||||
final long segOffset = pos & chunkSizeMask;
|
||||
if (checkIndex(segOffset + len, seg.byteSize() + 1)) {
|
||||
return seg.asSlice(segOffset, len);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
MemorySegmentIndexInput buildSlice(String sliceDescription, long ofs, long length) {
|
||||
return super.buildSlice(sliceDescription, this.offset + ofs, length);
|
||||
|
|
|
@ -21,6 +21,8 @@ import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
|
|||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.oneOf;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
|
@ -60,8 +62,9 @@ public class TestFlatVectorScorer extends LuceneTestCase {
|
|||
public static Iterable<Object[]> parametersFactory() {
|
||||
var scorers =
|
||||
List.of(
|
||||
new DefaultFlatVectorScorer(),
|
||||
new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()));
|
||||
DefaultFlatVectorScorer.INSTANCE,
|
||||
new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()),
|
||||
FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
|
||||
var dirs =
|
||||
List.<ThrowingSupplier<Directory>>of(
|
||||
TestFlatVectorScorer::newDirectory,
|
||||
|
@ -76,7 +79,14 @@ public class TestFlatVectorScorer extends LuceneTestCase {
|
|||
return objs;
|
||||
}
|
||||
|
||||
// Tests that the creation of another scorer does not perturb previous scorers
|
||||
public void testDefaultOrMemSegScorer() {
|
||||
var scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
|
||||
assertThat(
|
||||
scorer.toString(),
|
||||
is(oneOf("DefaultFlatVectorScorer()", "Lucene99MemorySegmentFlatVectorsScorer()")));
|
||||
}
|
||||
|
||||
// Tests that the creation of another scorer does not disturb previous scorers
|
||||
public void testMultipleByteScorers() throws IOException {
|
||||
byte[] vec0 = new byte[] {0, 0, 0, 0};
|
||||
byte[] vec1 = new byte[] {1, 1, 1, 1};
|
||||
|
|
|
@ -16,11 +16,15 @@
|
|||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static java.lang.String.format;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.oneOf;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
|
@ -219,9 +223,12 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
|||
10, 20, 1, (byte) 4, false, 0.9f, null);
|
||||
}
|
||||
};
|
||||
String expectedString =
|
||||
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())))";
|
||||
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
||||
String expectedPattern =
|
||||
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))";
|
||||
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
|
||||
var memSegScorer =
|
||||
format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
|
||||
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
|
||||
}
|
||||
|
||||
public void testLimits() {
|
||||
|
|
|
@ -16,6 +16,11 @@
|
|||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static java.lang.String.format;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.oneOf;
|
||||
|
||||
import java.util.Locale;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
|
@ -37,9 +42,12 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
|||
return new Lucene99HnswVectorsFormat(10, 20);
|
||||
}
|
||||
};
|
||||
String expectedString =
|
||||
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer()))";
|
||||
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
||||
String expectedPattern =
|
||||
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))";
|
||||
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
|
||||
var memSegScorer =
|
||||
format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
|
||||
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
|
||||
}
|
||||
|
||||
public void testLimits() {
|
||||
|
|
|
@ -0,0 +1,398 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.internal.vectorization;
|
||||
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.IntStream;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.NamedThreadFactory;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
public class TestVectorScorer extends LuceneTestCase {
|
||||
|
||||
private static final double DELTA = 1e-5;
|
||||
|
||||
static final FlatVectorsScorer DEFAULT_SCORER = DefaultFlatVectorScorer.INSTANCE;
|
||||
static final FlatVectorsScorer MEMSEG_SCORER =
|
||||
VectorizationProvider.lookup(true).getLucene99FlatVectorsScorer();
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
assumeTrue(
|
||||
"Test only works when the Memory segment scorer is present.",
|
||||
MEMSEG_SCORER.getClass() != DEFAULT_SCORER.getClass());
|
||||
}
|
||||
|
||||
public void testSimpleScorer() throws IOException {
|
||||
testSimpleScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE);
|
||||
}
|
||||
|
||||
public void testSimpleScorerSmallChunkSize() throws IOException {
|
||||
long maxChunkSize = random().nextLong(4, 16);
|
||||
testSimpleScorer(maxChunkSize);
|
||||
}
|
||||
|
||||
public void testSimpleScorerMedChunkSize() throws IOException {
|
||||
// a chunk size where in some vectors will be copied on-heap, while others remain off-heap
|
||||
testSimpleScorer(64);
|
||||
}
|
||||
|
||||
void testSimpleScorer(long maxChunkSize) throws IOException {
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testSimpleScorer"), maxChunkSize)) {
|
||||
for (int dims : List.of(31, 32, 33)) {
|
||||
// dimensions that, in some scenarios, cross the mmap chunk sizes
|
||||
byte[][] vectors = new byte[2][dims];
|
||||
String fileName = "bar-" + dims;
|
||||
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
|
||||
for (int i = 0; i < dims; i++) {
|
||||
vectors[0][i] = (byte) i;
|
||||
vectors[1][i] = (byte) (dims - i);
|
||||
}
|
||||
byte[] bytes = concat(vectors[0], vectors[1]);
|
||||
out.writeBytes(bytes, 0, bytes.length);
|
||||
}
|
||||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
|
||||
var vectorValues = vectorValues(dims, 2, in, sim);
|
||||
for (var ords : List.of(List.of(0, 1), List.of(1, 0))) {
|
||||
int idx0 = ords.get(0);
|
||||
int idx1 = ords.get(1);
|
||||
|
||||
// getRandomVectorScorerSupplier
|
||||
var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
float expected = scorer1.scorer(idx0).score(idx1);
|
||||
var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
|
||||
|
||||
// getRandomVectorScorer
|
||||
var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
|
||||
assertEquals(scorer3.score(idx1), expected, DELTA);
|
||||
var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
|
||||
assertEquals(scorer4.score(idx1), expected, DELTA);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testRandomScorer() throws IOException {
|
||||
testRandomScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomScorerMax() throws IOException {
|
||||
testRandomScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomScorerMin() throws IOException {
|
||||
testRandomScorer(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomSmallChunkSize() throws IOException {
|
||||
long maxChunkSize = randomLongBetween(32, 128);
|
||||
testRandomScorer(maxChunkSize, BYTE_ARRAY_RANDOM_FUNC);
|
||||
}
|
||||
|
||||
void testRandomScorer(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier)
|
||||
throws IOException {
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testRandomScorer"), maxChunkSize)) {
|
||||
final int dims = randomIntBetween(1, 4096);
|
||||
final int size = randomIntBetween(2, 100);
|
||||
final byte[][] vectors = new byte[size][];
|
||||
String fileName = "foo-" + dims;
|
||||
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
var vec = byteArraySupplier.apply(dims);
|
||||
out.writeBytes(vec, 0, vec.length);
|
||||
vectors[i] = vec;
|
||||
}
|
||||
}
|
||||
|
||||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
for (int times = 0; times < TIMES; times++) {
|
||||
for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
|
||||
var vectorValues = vectorValues(dims, size, in, sim);
|
||||
int idx0 = randomIntBetween(0, size - 1);
|
||||
int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok.
|
||||
|
||||
// getRandomVectorScorerSupplier
|
||||
var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
float expected = scorer1.scorer(idx0).score(idx1);
|
||||
var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
|
||||
|
||||
// getRandomVectorScorer
|
||||
var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
|
||||
assertEquals(scorer3.score(idx1), expected, DELTA);
|
||||
var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
|
||||
assertEquals(scorer4.score(idx1), expected, DELTA);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testRandomSliceSmall() throws IOException {
|
||||
testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomSlice() throws IOException {
|
||||
int dims = randomIntBetween(1, 4096);
|
||||
long maxChunkSize = randomLongBetween(32, 128);
|
||||
int initialOffset = randomIntBetween(1, 129);
|
||||
testRandomSliceImpl(dims, maxChunkSize, initialOffset, BYTE_ARRAY_RANDOM_FUNC);
|
||||
}
|
||||
|
||||
// Tests with a slice that has a non-zero initial offset
|
||||
void testRandomSliceImpl(
|
||||
int dims, long maxChunkSize, int initialOffset, Function<Integer, byte[]> byteArraySupplier)
|
||||
throws IOException {
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) {
|
||||
final int size = randomIntBetween(2, 100);
|
||||
final byte[][] vectors = new byte[size][];
|
||||
String fileName = "baz-" + dims;
|
||||
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
|
||||
byte[] ba = new byte[initialOffset];
|
||||
out.writeBytes(ba, 0, ba.length);
|
||||
for (int i = 0; i < size; i++) {
|
||||
var vec = byteArraySupplier.apply(dims);
|
||||
out.writeBytes(vec, 0, vec.length);
|
||||
vectors[i] = vec;
|
||||
}
|
||||
}
|
||||
|
||||
try (var outter = dir.openInput(fileName, IOContext.DEFAULT);
|
||||
var in = outter.slice("slice", initialOffset, outter.length() - initialOffset)) {
|
||||
for (int times = 0; times < TIMES; times++) {
|
||||
for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
|
||||
var vectorValues = vectorValues(dims, size, in, sim);
|
||||
int idx0 = randomIntBetween(0, size - 1);
|
||||
int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok.
|
||||
|
||||
// getRandomVectorScorerSupplier
|
||||
var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
float expected = scorer1.scorer(idx0).score(idx1);
|
||||
var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
|
||||
|
||||
// getRandomVectorScorer
|
||||
var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
|
||||
assertEquals(scorer3.score(idx1), expected, DELTA);
|
||||
var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, vectors[idx0]);
|
||||
assertEquals(scorer4.score(idx1), expected, DELTA);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that copies in threads do not interfere with each other
|
||||
public void testCopiesAcrossThreads() throws Exception {
|
||||
final long maxChunkSize = 32;
|
||||
final int dims = 34; // dimensions that are larger than the chunk size, to force fallback
|
||||
byte[] vec1 = new byte[dims];
|
||||
byte[] vec2 = new byte[dims];
|
||||
IntStream.range(0, dims).forEach(i -> vec1[i] = 1);
|
||||
IntStream.range(0, dims).forEach(i -> vec2[i] = 2);
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) {
|
||||
String fileName = "biz-" + dims;
|
||||
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
|
||||
byte[] bytes = concat(vec1, vec1, vec2, vec2);
|
||||
out.writeBytes(bytes, 0, bytes.length);
|
||||
}
|
||||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
|
||||
var vectorValues = vectorValues(dims, 4, in, sim);
|
||||
var scoreSupplier = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
var expectedScore1 = scoreSupplier.scorer(0).score(1);
|
||||
var expectedScore2 = scoreSupplier.scorer(2).score(3);
|
||||
|
||||
var scorer = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
var tasks =
|
||||
List.<Callable<Optional<Throwable>>>of(
|
||||
new AssertingScoreCallable(scorer.copy().scorer(0), 1, expectedScore1),
|
||||
new AssertingScoreCallable(scorer.copy().scorer(2), 3, expectedScore2));
|
||||
var executor = Executors.newFixedThreadPool(2, new NamedThreadFactory("copiesThreads"));
|
||||
var results = executor.invokeAll(tasks);
|
||||
executor.shutdown();
|
||||
assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS));
|
||||
assertEquals(results.stream().filter(Predicate.not(Future::isDone)).count(), 0L);
|
||||
for (var res : results) {
|
||||
assertTrue("Unexpected exception" + res.get(), res.get().isEmpty());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A callable that scores the given ord and scorer and asserts the expected result.
|
||||
static class AssertingScoreCallable implements Callable<Optional<Throwable>> {
|
||||
final RandomVectorScorer scorer;
|
||||
final int ord;
|
||||
final float expectedScore;
|
||||
|
||||
AssertingScoreCallable(RandomVectorScorer scorer, int ord, float expectedScore) {
|
||||
this.scorer = scorer;
|
||||
this.ord = ord;
|
||||
this.expectedScore = expectedScore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<Throwable> call() throws Exception {
|
||||
try {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
assertEquals(scorer.score(ord), expectedScore, DELTA);
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
return Optional.of(t);
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
// Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow
|
||||
@Nightly
|
||||
public void testLarge() throws IOException {
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) {
|
||||
final int dims = 8192;
|
||||
final int size = 262500;
|
||||
final String fileName = "large-" + dims;
|
||||
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
var vec = vector(i, dims);
|
||||
out.writeBytes(vec, 0, vec.length);
|
||||
}
|
||||
}
|
||||
|
||||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
assert in.length() > Integer.MAX_VALUE;
|
||||
for (int times = 0; times < TIMES; times++) {
|
||||
for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
|
||||
var vectorValues = vectorValues(dims, size, in, sim);
|
||||
int ord1 = randomIntBetween(0, size - 1);
|
||||
int ord2 = size - 1;
|
||||
for (var ords : List.of(List.of(ord1, ord2), List.of(ord2, ord1))) {
|
||||
int idx0 = ords.getFirst();
|
||||
int idx1 = ords.getLast();
|
||||
|
||||
// getRandomVectorScorerSupplier
|
||||
var scorer1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
float expected = scorer1.scorer(idx0).score(idx1);
|
||||
var scorer2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
|
||||
assertEquals(scorer2.scorer(idx0).score(idx1), expected, DELTA);
|
||||
|
||||
// getRandomVectorScorer
|
||||
var query = vector(idx0, dims);
|
||||
var scorer3 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, query);
|
||||
assertEquals(scorer3.score(idx1), expected, DELTA);
|
||||
var scorer4 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, query);
|
||||
assertEquals(scorer4.score(idx1), expected, DELTA);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RandomAccessVectorValues vectorValues(
|
||||
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
|
||||
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim);
|
||||
}
|
||||
|
||||
// creates the vector based on the given ordinal, which is reproducible given the ord and dims
|
||||
static byte[] vector(int ord, int dims) {
|
||||
var random = new Random(Objects.hash(ord, dims));
|
||||
byte[] ba = new byte[dims];
|
||||
for (int i = 0; i < dims; i++) {
|
||||
ba[i] = (byte) RandomNumbers.randomIntBetween(random, Byte.MIN_VALUE, Byte.MAX_VALUE);
|
||||
}
|
||||
return ba;
|
||||
}
|
||||
|
||||
/** Concatenates byte arrays. */
|
||||
static byte[] concat(byte[]... arrays) throws IOException {
|
||||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
|
||||
for (var ba : arrays) {
|
||||
baos.write(ba);
|
||||
}
|
||||
return baos.toByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
static int randomIntBetween(int minInclusive, int maxInclusive) {
|
||||
return RandomNumbers.randomIntBetween(random(), minInclusive, maxInclusive);
|
||||
}
|
||||
|
||||
static long randomLongBetween(long minInclusive, long maxInclusive) {
|
||||
return RandomNumbers.randomLongBetween(random(), minInclusive, maxInclusive);
|
||||
}
|
||||
|
||||
static Function<Integer, byte[]> BYTE_ARRAY_RANDOM_FUNC =
|
||||
size -> {
|
||||
byte[] ba = new byte[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
ba[i] = (byte) random().nextInt();
|
||||
}
|
||||
return ba;
|
||||
};
|
||||
|
||||
static Function<Integer, byte[]> BYTE_ARRAY_MAX_FUNC =
|
||||
size -> {
|
||||
byte[] ba = new byte[size];
|
||||
Arrays.fill(ba, Byte.MAX_VALUE);
|
||||
return ba;
|
||||
};
|
||||
|
||||
static Function<Integer, byte[]> BYTE_ARRAY_MIN_FUNC =
|
||||
size -> {
|
||||
byte[] ba = new byte[size];
|
||||
Arrays.fill(ba, Byte.MIN_VALUE);
|
||||
return ba;
|
||||
};
|
||||
|
||||
static final int TIMES = 100; // a loop iteration times
|
||||
}
|
|
@ -50,6 +50,7 @@ import org.apache.lucene.store.Directory;
|
|||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.tests.codecs.asserting.AssertingCodec;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
|
@ -77,6 +78,13 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
|
||||
abstract Field getKnnVectorField(String name, float[] vector);
|
||||
|
||||
/**
|
||||
* Creates a new directory. Subclasses can override to test different directory implementations.
|
||||
*/
|
||||
protected BaseDirectoryWrapper newDirectoryForTest() {
|
||||
return LuceneTestCase.newDirectory(random());
|
||||
}
|
||||
|
||||
public void testEquals() {
|
||||
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
|
||||
Query filter1 = new TermQuery(new Term("id", "id1"));
|
||||
|
@ -308,7 +316,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testScoreCosine() throws IOException {
|
||||
try (Directory d = newDirectory()) {
|
||||
try (Directory d = newDirectoryForTest()) {
|
||||
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
|
||||
for (int j = 1; j <= 5; j++) {
|
||||
Document doc = new Document();
|
||||
|
@ -383,7 +391,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testExplain() throws IOException {
|
||||
try (Directory d = newDirectory()) {
|
||||
try (Directory d = newDirectoryForTest()) {
|
||||
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
|
||||
for (int j = 0; j < 5; j++) {
|
||||
Document doc = new Document();
|
||||
|
@ -410,7 +418,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testExplainMultipleSegments() throws IOException {
|
||||
try (Directory d = newDirectory()) {
|
||||
try (Directory d = newDirectoryForTest()) {
|
||||
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
|
||||
for (int j = 0; j < 5; j++) {
|
||||
Document doc = new Document();
|
||||
|
@ -443,7 +451,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
* number of top K documents, but no more than K documents in total (otherwise we might occasionally
|
||||
* randomly fail to find one).
|
||||
*/
|
||||
try (Directory d = newDirectory()) {
|
||||
try (Directory d = newDirectoryForTest()) {
|
||||
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
|
||||
int r = 0;
|
||||
for (int i = 0; i < 5; i++) {
|
||||
|
@ -479,7 +487,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
int dimension = atLeast(5);
|
||||
int numIters = atLeast(10);
|
||||
boolean everyDocHasAVector = random().nextBoolean();
|
||||
try (Directory d = newDirectory()) {
|
||||
try (Directory d = newDirectoryForTest()) {
|
||||
RandomIndexWriter w = new RandomIndexWriter(random(), d);
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
Document doc = new Document();
|
||||
|
@ -518,7 +526,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
int numDocs = 1000;
|
||||
int dimension = atLeast(5);
|
||||
int numIters = atLeast(10);
|
||||
try (Directory d = newDirectory()) {
|
||||
try (Directory d = newDirectoryForTest()) {
|
||||
// Always use the default kNN format to have predictable behavior around when it hits
|
||||
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
|
||||
// format
|
||||
|
@ -604,7 +612,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
public void testFilterWithSameScore() throws IOException {
|
||||
int numDocs = 100;
|
||||
int dimension = atLeast(5);
|
||||
try (Directory d = newDirectory()) {
|
||||
try (Directory d = newDirectoryForTest()) {
|
||||
// Always use the default kNN format to have predictable behavior around when it hits
|
||||
// visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN
|
||||
// format
|
||||
|
@ -644,7 +652,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testDeletes() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
try (Directory dir = newDirectoryForTest();
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
final int numDocs = atLeast(100);
|
||||
final int dim = 30;
|
||||
|
@ -688,7 +696,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testAllDeletes() throws IOException {
|
||||
try (Directory dir = newDirectory();
|
||||
try (Directory dir = newDirectoryForTest();
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
final int numDocs = atLeast(100);
|
||||
final int dim = 30;
|
||||
|
@ -717,7 +725,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
*/
|
||||
public void testNoLiveDocsReader() throws IOException {
|
||||
IndexWriterConfig iwc = newIndexWriterConfig();
|
||||
try (Directory dir = newDirectory();
|
||||
try (Directory dir = newDirectoryForTest();
|
||||
IndexWriter w = new IndexWriter(dir, iwc)) {
|
||||
final int numDocs = 10;
|
||||
final int dim = 30;
|
||||
|
@ -745,7 +753,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
*/
|
||||
public void testBitSetQuery() throws IOException {
|
||||
IndexWriterConfig iwc = newIndexWriterConfig();
|
||||
try (Directory dir = newDirectory();
|
||||
try (Directory dir = newDirectoryForTest();
|
||||
IndexWriter w = new IndexWriter(dir, iwc)) {
|
||||
final int numDocs = 100;
|
||||
final int dim = 30;
|
||||
|
@ -853,7 +861,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
Directory getIndexStore(
|
||||
String field, VectorSimilarityFunction vectorSimilarityFunction, float[]... contents)
|
||||
throws IOException {
|
||||
Directory indexStore = newDirectory();
|
||||
Directory indexStore = newDirectoryForTest();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
||||
for (int i = 0; i < contents.length; ++i) {
|
||||
Document doc = new Document();
|
||||
|
@ -886,7 +894,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
* preserving the order of the added documents.
|
||||
*/
|
||||
private Directory getStableIndexStore(String field, float[]... contents) throws IOException {
|
||||
Directory indexStore = newDirectory();
|
||||
Directory indexStore = newDirectoryForTest();
|
||||
try (IndexWriter writer = new IndexWriter(indexStore, new IndexWriterConfig())) {
|
||||
for (int i = 0; i < contents.length; ++i) {
|
||||
Document doc = new Document();
|
||||
|
@ -1031,7 +1039,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testSameFieldDifferentFormats() throws IOException {
|
||||
try (Directory directory = newDirectory()) {
|
||||
try (Directory directory = newDirectoryForTest()) {
|
||||
MockAnalyzer mockAnalyzer = new MockAnalyzer(random());
|
||||
IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
|
||||
KnnVectorsFormat format1 = randomVectorFormat(VectorEncoding.FLOAT32);
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
|
||||
import org.apache.lucene.tests.store.MockDirectoryWrapper;
|
||||
|
||||
public class TestKnnByteVectorQueryMMap extends TestKnnByteVectorQuery {
|
||||
|
||||
@Override
|
||||
protected BaseDirectoryWrapper newDirectoryForTest() {
|
||||
try {
|
||||
return new MockDirectoryWrapper(
|
||||
random(), new MMapDirectory(createTempDir("TestKnnByteVectorQueryMMap")));
|
||||
} catch (IOException e) {
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ import java.io.Closeable;
|
|||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import org.apache.lucene.internal.tests.TestSecrets;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
||||
|
@ -27,6 +28,11 @@ import org.apache.lucene.store.IndexInput;
|
|||
* Used by MockDirectoryWrapper to create an input stream that keeps track of when it's been closed.
|
||||
*/
|
||||
public class MockIndexInputWrapper extends FilterIndexInput {
|
||||
|
||||
static {
|
||||
TestSecrets.getFilterInputIndexAccess().addTestFilterType(MockIndexInputWrapper.class);
|
||||
}
|
||||
|
||||
private MockDirectoryWrapper dir;
|
||||
final String name;
|
||||
private volatile boolean closed;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.apache.lucene.tests.store;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.internal.tests.TestSecrets;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.SuppressForbidden;
|
||||
import org.apache.lucene.util.ThreadInterruptedException;
|
||||
|
@ -28,6 +29,11 @@ import org.apache.lucene.util.ThreadInterruptedException;
|
|||
*/
|
||||
class SlowClosingMockIndexInputWrapper extends MockIndexInputWrapper {
|
||||
|
||||
static {
|
||||
TestSecrets.getFilterInputIndexAccess()
|
||||
.addTestFilterType(SlowClosingMockIndexInputWrapper.class);
|
||||
}
|
||||
|
||||
public SlowClosingMockIndexInputWrapper(
|
||||
MockDirectoryWrapper dir, String name, IndexInput delegate) {
|
||||
super(dir, name, delegate, null);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.apache.lucene.tests.store;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.internal.tests.TestSecrets;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.SuppressForbidden;
|
||||
import org.apache.lucene.util.ThreadInterruptedException;
|
||||
|
@ -27,6 +28,11 @@ import org.apache.lucene.util.ThreadInterruptedException;
|
|||
*/
|
||||
class SlowOpeningMockIndexInputWrapper extends MockIndexInputWrapper {
|
||||
|
||||
static {
|
||||
TestSecrets.getFilterInputIndexAccess()
|
||||
.addTestFilterType(SlowOpeningMockIndexInputWrapper.class);
|
||||
}
|
||||
|
||||
@SuppressForbidden(reason = "Thread sleep")
|
||||
public SlowOpeningMockIndexInputWrapper(
|
||||
MockDirectoryWrapper dir, String name, IndexInput delegate) throws IOException {
|
||||
|
|
Loading…
Reference in New Issue