mirror of https://github.com/apache/lucene.git
Leverage accelerated vector hardware instructions in Vector Search. Lucene already has a mechanism that enables the use of non-final JDK APIs, currently used for the Previewing Pamana Foreign API. This change expands this mechanism to include the Incubating Pamana Vector API. When the jdk.incubator.vector module is present at run time the Panamaized version of the low-level primitives used by Vector Search is enabled. If not present, the default scalar version of these low-level primitives is used (as it was previously). Currently, we're only targeting support for JDK 20. A subsequent PR should evaluate JDK 21. --------- Co-authored-by: Uwe Schindler <uschindler@apache.org> Co-authored-by: Robert Muir <rmuir@apache.org>
This commit is contained in:
parent
7db2c128ab
commit
0b53670411
|
@ -119,7 +119,7 @@ apply from: file('gradle/ide/eclipse.gradle')
|
|||
// (java, tests)
|
||||
apply from: file('gradle/java/folder-layout.gradle')
|
||||
apply from: file('gradle/java/javac.gradle')
|
||||
apply from: file('gradle/java/memorysegment-mrjar.gradle')
|
||||
apply from: file('gradle/java/core-mrjar.gradle')
|
||||
apply from: file('gradle/testing/defaults-tests.gradle')
|
||||
apply from: file('gradle/testing/randomization.gradle')
|
||||
apply from: file('gradle/testing/fail-on-no-tests.gradle')
|
||||
|
@ -158,7 +158,7 @@ apply from: file('gradle/generation/javacc.gradle')
|
|||
apply from: file('gradle/generation/forUtil.gradle')
|
||||
apply from: file('gradle/generation/antlr.gradle')
|
||||
apply from: file('gradle/generation/unicode-test-classes.gradle')
|
||||
apply from: file('gradle/generation/panama-foreign.gradle')
|
||||
apply from: file('gradle/generation/extract-jdk-apis.gradle')
|
||||
|
||||
apply from: file('gradle/datasets/external-datasets.gradle')
|
||||
|
||||
|
|
|
@ -17,10 +17,17 @@
|
|||
|
||||
def resources = scriptResources(buildscript)
|
||||
|
||||
configure(rootProject) {
|
||||
ext {
|
||||
// also change this in extractor tool: ExtractForeignAPI
|
||||
vectorIncubatorJavaVersions = [ JavaVersion.VERSION_20 ] as Set
|
||||
}
|
||||
}
|
||||
|
||||
configure(project(":lucene:core")) {
|
||||
ext {
|
||||
apijars = file('src/generated/jdk');
|
||||
panamaJavaVersions = [ 19, 20 ]
|
||||
mrjarJavaVersions = [ 19, 20 ]
|
||||
}
|
||||
|
||||
configurations {
|
||||
|
@ -31,9 +38,9 @@ configure(project(":lucene:core")) {
|
|||
apiextractor "org.ow2.asm:asm:${scriptDepVersions['asm']}"
|
||||
}
|
||||
|
||||
for (jdkVersion : panamaJavaVersions) {
|
||||
def task = tasks.create(name: "generatePanamaForeignApiJar${jdkVersion}", type: JavaExec) {
|
||||
description "Regenerate the API-only JAR file with public Panama Foreign API from JDK ${jdkVersion}"
|
||||
for (jdkVersion : mrjarJavaVersions) {
|
||||
def task = tasks.create(name: "generateJdkApiJar${jdkVersion}", type: JavaExec) {
|
||||
description "Regenerate the API-only JAR file with public Panama Foreign & Vector API from JDK ${jdkVersion}"
|
||||
group "generation"
|
||||
|
||||
javaLauncher = javaToolchains.launcherFor {
|
||||
|
@ -45,21 +52,21 @@ configure(project(":lucene:core")) {
|
|||
javaLauncher.get()
|
||||
return true
|
||||
} catch (Exception e) {
|
||||
logger.warn('Launcher for Java {} is not available; skipping regeneration of Panama Foreign API JAR.', jdkVersion)
|
||||
logger.warn('Launcher for Java {} is not available; skipping regeneration of Panama Foreign & Vector API JAR.', jdkVersion)
|
||||
logger.warn('Error: {}', e.cause?.message)
|
||||
logger.warn("Please make sure to point env 'JAVA{}_HOME' to exactly JDK version {} or enable Gradle toolchain auto-download.", jdkVersion, jdkVersion)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
classpath = configurations.apiextractor
|
||||
mainClass = file("${resources}/ExtractForeignAPI.java") as String
|
||||
mainClass = file("${resources}/ExtractJdkApis.java") as String
|
||||
systemProperties = [
|
||||
'user.timezone': 'UTC'
|
||||
]
|
||||
args = [
|
||||
jdkVersion,
|
||||
new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar"),
|
||||
new File(apijars, "jdk${jdkVersion}.apijar"),
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.PathMatcher;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.attribute.FileTime;
|
||||
import java.time.Instant;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
|
||||
import org.objectweb.asm.AnnotationVisitor;
|
||||
import org.objectweb.asm.ClassReader;
|
||||
import org.objectweb.asm.ClassVisitor;
|
||||
import org.objectweb.asm.ClassWriter;
|
||||
import org.objectweb.asm.FieldVisitor;
|
||||
import org.objectweb.asm.MethodVisitor;
|
||||
import org.objectweb.asm.Opcodes;
|
||||
import org.objectweb.asm.Type;
|
||||
|
||||
public final class ExtractJdkApis {
|
||||
|
||||
private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z"));
|
||||
|
||||
private static final String PATTERN_PANAMA_FOREIGN = "java.base/java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}";
|
||||
private static final String PATTERN_VECTOR_INCUBATOR = "jdk.incubator.vector/jdk/incubator/vector/*";
|
||||
private static final String PATTERN_VECTOR_VM_INTERNALS = "java.base/jdk/internal/vm/vector/VectorSupport{,$Vector,$VectorMask,$VectorPayload,$VectorShuffle}";
|
||||
|
||||
static final Map<Integer,List<String>> CLASSFILE_PATTERNS = Map.of(
|
||||
19, List.of(PATTERN_PANAMA_FOREIGN),
|
||||
20, List.of(PATTERN_PANAMA_FOREIGN, PATTERN_VECTOR_VM_INTERNALS, PATTERN_VECTOR_INCUBATOR)
|
||||
);
|
||||
|
||||
public static void main(String... args) throws IOException {
|
||||
if (args.length != 2) {
|
||||
throw new IllegalArgumentException("Need two parameters: java version, output file");
|
||||
}
|
||||
Integer jdk = Integer.valueOf(args[0]);
|
||||
if (jdk.intValue() != Runtime.version().feature()) {
|
||||
throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature());
|
||||
}
|
||||
if (!CLASSFILE_PATTERNS.containsKey(jdk)) {
|
||||
throw new IllegalArgumentException("No support to extract stubs from java version: " + jdk);
|
||||
}
|
||||
var outputPath = Paths.get(args[1]);
|
||||
|
||||
// create JRT filesystem and build a combined FileMatcher:
|
||||
var jrtPath = Paths.get(URI.create("jrt:/")).toRealPath();
|
||||
var patterns = CLASSFILE_PATTERNS.get(jdk).stream()
|
||||
.map(pattern -> jrtPath.getFileSystem().getPathMatcher("glob:" + pattern + ".class"))
|
||||
.toArray(PathMatcher[]::new);
|
||||
PathMatcher pattern = p -> Arrays.stream(patterns).anyMatch(matcher -> matcher.matches(p));
|
||||
|
||||
// Collect all files to process:
|
||||
final List<Path> filesToExtract;
|
||||
try (var stream = Files.walk(jrtPath)) {
|
||||
filesToExtract = stream.filter(p -> pattern.matches(jrtPath.relativize(p))).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
// Process all class files:
|
||||
try (var out = new ZipOutputStream(Files.newOutputStream(outputPath))) {
|
||||
process(filesToExtract, out);
|
||||
}
|
||||
}
|
||||
|
||||
private static void process(List<Path> filesToExtract, ZipOutputStream out) throws IOException {
|
||||
var classesToInclude = new HashSet<String>();
|
||||
var references = new HashMap<String, String[]>();
|
||||
var processed = new TreeMap<String, byte[]>();
|
||||
System.out.println("Transforming " + filesToExtract.size() + " class files...");
|
||||
for (Path p : filesToExtract) {
|
||||
try (var in = Files.newInputStream(p)) {
|
||||
var reader = new ClassReader(in);
|
||||
var cw = new ClassWriter(0);
|
||||
var cleaner = new Cleaner(cw, classesToInclude, references);
|
||||
reader.accept(cleaner, ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
|
||||
processed.put(reader.getClassName(), cw.toByteArray());
|
||||
}
|
||||
}
|
||||
// recursively add all superclasses / interfaces of visible classes to classesToInclude:
|
||||
for (Set<String> a = classesToInclude; !a.isEmpty();) {
|
||||
a = a.stream().map(references::get).filter(Objects::nonNull).flatMap(Arrays::stream).collect(Collectors.toSet());
|
||||
classesToInclude.addAll(a);
|
||||
}
|
||||
// remove all non-visible or not referenced classes:
|
||||
processed.keySet().removeIf(Predicate.not(classesToInclude::contains));
|
||||
System.out.println("Writing " + processed.size() + " visible classes...");
|
||||
for (var cls : processed.entrySet()) {
|
||||
String cn = cls.getKey();
|
||||
System.out.println("Writing stub for class: " + cn);
|
||||
out.putNextEntry(new ZipEntry(cn.concat(".class")).setLastModifiedTime(FIXED_FILEDATE));
|
||||
out.write(cls.getValue());
|
||||
out.closeEntry();
|
||||
}
|
||||
classesToInclude.removeIf(processed.keySet()::contains);
|
||||
System.out.println("Referenced classes not included: " + classesToInclude);
|
||||
}
|
||||
|
||||
static boolean isVisible(int access) {
|
||||
return (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) != 0;
|
||||
}
|
||||
|
||||
static class Cleaner extends ClassVisitor {
|
||||
private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature";
|
||||
private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor();
|
||||
|
||||
private final Set<String> classesToInclude;
|
||||
private final Map<String, String[]> references;
|
||||
|
||||
Cleaner(ClassWriter out, Set<String> classesToInclude, Map<String, String[]> references) {
|
||||
super(Opcodes.ASM9, out);
|
||||
this.classesToInclude = classesToInclude;
|
||||
this.references = references;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
|
||||
super.visit(Opcodes.V11, access, name, signature, superName, interfaces);
|
||||
if (isVisible(access)) {
|
||||
classesToInclude.add(name);
|
||||
}
|
||||
references.put(name, Stream.concat(Stream.of(superName), Arrays.stream(interfaces)).toArray(String[]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
|
||||
return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
|
||||
if (!isVisible(access)) {
|
||||
return null;
|
||||
}
|
||||
return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) {
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
|
||||
return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
|
||||
if (!isVisible(access)) {
|
||||
return null;
|
||||
}
|
||||
return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) {
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
|
||||
return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitInnerClass(String name, String outerName, String innerName, int access) {
|
||||
if (!Objects.equals(outerName, PREVIEW_ANN)) {
|
||||
super.visitInnerClass(name, outerName, innerName, access);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPermittedSubclass(String c) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -1,132 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.attribute.FileTime;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
|
||||
import org.objectweb.asm.AnnotationVisitor;
|
||||
import org.objectweb.asm.ClassReader;
|
||||
import org.objectweb.asm.ClassVisitor;
|
||||
import org.objectweb.asm.ClassWriter;
|
||||
import org.objectweb.asm.FieldVisitor;
|
||||
import org.objectweb.asm.MethodVisitor;
|
||||
import org.objectweb.asm.Opcodes;
|
||||
import org.objectweb.asm.Type;
|
||||
|
||||
public final class ExtractForeignAPI {
|
||||
|
||||
private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z"));
|
||||
|
||||
public static void main(String... args) throws IOException {
|
||||
if (args.length != 2) {
|
||||
throw new IllegalArgumentException("Need two parameters: java version, output file");
|
||||
}
|
||||
if (Integer.parseInt(args[0]) != Runtime.version().feature()) {
|
||||
throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature());
|
||||
}
|
||||
var outputPath = Paths.get(args[1]);
|
||||
var javaBaseModule = Paths.get(URI.create("jrt:/")).resolve("java.base").toRealPath();
|
||||
var fileMatcher = javaBaseModule.getFileSystem().getPathMatcher("glob:java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}.class");
|
||||
try (var out = new ZipOutputStream(Files.newOutputStream(outputPath)); var stream = Files.walk(javaBaseModule)) {
|
||||
var filesToExtract = stream.map(javaBaseModule::relativize).filter(fileMatcher::matches).sorted().collect(Collectors.toList());
|
||||
for (Path relative : filesToExtract) {
|
||||
System.out.println("Processing class file: " + relative);
|
||||
try (var in = Files.newInputStream(javaBaseModule.resolve(relative))) {
|
||||
final var reader = new ClassReader(in);
|
||||
final var cw = new ClassWriter(0);
|
||||
reader.accept(new Cleaner(cw), ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
|
||||
out.putNextEntry(new ZipEntry(relative.toString()).setLastModifiedTime(FIXED_FILEDATE));
|
||||
out.write(cw.toByteArray());
|
||||
out.closeEntry();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static class Cleaner extends ClassVisitor {
|
||||
private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature";
|
||||
private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor();
|
||||
|
||||
private boolean completelyHidden = false;
|
||||
|
||||
Cleaner(ClassWriter out) {
|
||||
super(Opcodes.ASM9, out);
|
||||
}
|
||||
|
||||
private boolean isHidden(int access) {
|
||||
return completelyHidden || (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
|
||||
super.visit(Opcodes.V11, access, name, signature, superName, interfaces);
|
||||
completelyHidden = isHidden(access);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
|
||||
return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
|
||||
if (isHidden(access)) {
|
||||
return null;
|
||||
}
|
||||
return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) {
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
|
||||
return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
|
||||
if (isHidden(access)) {
|
||||
return null;
|
||||
}
|
||||
return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) {
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
|
||||
return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitInnerClass(String name, String outerName, String innerName, int access) {
|
||||
if (!Objects.equals(outerName, PREVIEW_ANN)) {
|
||||
super.visitInnerClass(name, outerName, innerName, access);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPermittedSubclass(String c) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -15,11 +15,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Produce an MR-JAR with Java 19+ MemorySegment implementation for MMapDirectory
|
||||
// Produce an MR-JAR with Java 19+ foreign and vector implementations
|
||||
|
||||
configure(project(":lucene:core")) {
|
||||
plugins.withType(JavaPlugin) {
|
||||
for (jdkVersion : panamaJavaVersions) {
|
||||
for (jdkVersion : mrjarJavaVersions) {
|
||||
sourceSets.create("main${jdkVersion}") {
|
||||
java {
|
||||
srcDirs = ["src/java${jdkVersion}"]
|
||||
|
@ -29,7 +29,7 @@ configure(project(":lucene:core")) {
|
|||
dependencies.add("main${jdkVersion}Implementation", sourceSets.main.output)
|
||||
|
||||
tasks.named("compileMain${jdkVersion}Java").configure {
|
||||
def apijar = new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar")
|
||||
def apijar = new File(apijars, "jdk${jdkVersion}.apijar")
|
||||
|
||||
inputs.file(apijar)
|
||||
|
||||
|
@ -40,12 +40,14 @@ configure(project(":lucene:core")) {
|
|||
"-Xlint:-options",
|
||||
"--patch-module", "java.base=${apijar}",
|
||||
"--add-exports", "java.base/java.lang.foreign=ALL-UNNAMED",
|
||||
// for compilation we patch the incubator packages into java.base, this has no effect on resulting class files:
|
||||
"--add-exports", "java.base/jdk.incubator.vector=ALL-UNNAMED",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
tasks.named('jar').configure {
|
||||
for (jdkVersion : panamaJavaVersions) {
|
||||
for (jdkVersion : mrjarJavaVersions) {
|
||||
into("META-INF/versions/${jdkVersion}") {
|
||||
from sourceSets["main${jdkVersion}"].output
|
||||
}
|
|
@ -47,7 +47,7 @@ allprojects {
|
|||
description: "Number of forked test JVMs"],
|
||||
[propName: 'tests.haltonfailure', value: true, description: "Halt processing on test failure."],
|
||||
[propName: 'tests.jvmargs',
|
||||
value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") },
|
||||
value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", isCIBuild ? "" : "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") },
|
||||
description: "Arguments passed to each forked JVM."],
|
||||
// Other settings.
|
||||
[propName: 'tests.neverUpToDate', value: true,
|
||||
|
@ -119,11 +119,16 @@ allprojects {
|
|||
if (rootProject.runtimeJavaVersion < JavaVersion.VERSION_16) {
|
||||
jvmArgs '--illegal-access=deny'
|
||||
}
|
||||
|
||||
|
||||
// Lucene needs to optional modules at runtime, which we want to enforce for testing
|
||||
// (if the runner JVM does not support them, it will fail tests):
|
||||
jvmArgs '--add-modules', 'jdk.unsupported,jdk.management'
|
||||
|
||||
// Enable the vector incubator module on supported Java versions:
|
||||
if (rootProject.vectorIncubatorJavaVersions.contains(rootProject.runtimeJavaVersion)) {
|
||||
jvmArgs '--add-modules', 'jdk.incubator.vector'
|
||||
}
|
||||
|
||||
def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties")
|
||||
def tempDir = layout.projectDirectory.dir(testsTmpDir.toString())
|
||||
jvmArgumentProviders.add(
|
||||
|
|
|
@ -20,6 +20,15 @@ New Features
|
|||
|
||||
* GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph to be searched in a thread-safety manner. (Patrick Zhai)
|
||||
|
||||
* GITHUB#12302, GITHUB#12311: Add vectorized implementations of VectorUtil.dotProduct(),
|
||||
squareDistance(), cosine() with Java 20 jdk.incubator.vector APIs. Applications started
|
||||
with command line parameter "java --add-modules jdk.incubator.vector" on exactly Java 20
|
||||
will automatically use the new vectorized implementations if running on a supported platform
|
||||
(x86 AVX2 or later, ARM SVE or later). This is an opt-in feature and requires explicit Java
|
||||
command line flag! When enabled, Lucene logs a notice using java.util.logging. Please test
|
||||
thoroughly and report bugs/slowness to Lucene's mailing list.
|
||||
(Chris Hegarty, Robert Muir, Uwe Schindler)
|
||||
|
||||
Improvements
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -40,4 +40,4 @@ to point the Lucene build system to missing JDK versions. The regeneration task
|
|||
a warning if a specific JDK is missing, leaving the already existing `.apijar` file
|
||||
untouched.
|
||||
|
||||
The extraction is done with the ASM library, see `ExtractForeignAPI.java` source code.
|
||||
The extraction is done with the ASM library, see `ExtractJdkApis.java` source code.
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -90,8 +90,8 @@ public enum VectorSimilarityFunction {
|
|||
|
||||
/**
|
||||
* Calculates a similarity score between the two vectors with a specified function. Higher
|
||||
* similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs
|
||||
* determine the vector data that is compared. Each (signed) byte represents a vector dimension.
|
||||
* similarity scores correspond to closer vectors. Each (signed) byte represents a vector
|
||||
* dimension.
|
||||
*
|
||||
* @param v1 a vector
|
||||
* @param v2 another vector, of the same dimension
|
||||
|
|
|
@ -20,6 +20,9 @@ package org.apache.lucene.util;
|
|||
/** Utilities for computations with numeric arrays */
|
||||
public final class VectorUtil {
|
||||
|
||||
// visible for testing
|
||||
static final VectorUtilProvider PROVIDER = VectorUtilProvider.lookup();
|
||||
|
||||
private VectorUtil() {}
|
||||
|
||||
/**
|
||||
|
@ -31,68 +34,7 @@ public final class VectorUtil {
|
|||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
float res = 0f;
|
||||
/*
|
||||
* If length of vector is larger than 8, we use unrolled dot product to accelerate the
|
||||
* calculation.
|
||||
*/
|
||||
int i;
|
||||
for (i = 0; i < a.length % 8; i++) {
|
||||
res += b[i] * a[i];
|
||||
}
|
||||
if (a.length < 8) {
|
||||
return res;
|
||||
}
|
||||
for (; i + 31 < a.length; i += 32) {
|
||||
res +=
|
||||
b[i + 0] * a[i + 0]
|
||||
+ b[i + 1] * a[i + 1]
|
||||
+ b[i + 2] * a[i + 2]
|
||||
+ b[i + 3] * a[i + 3]
|
||||
+ b[i + 4] * a[i + 4]
|
||||
+ b[i + 5] * a[i + 5]
|
||||
+ b[i + 6] * a[i + 6]
|
||||
+ b[i + 7] * a[i + 7];
|
||||
res +=
|
||||
b[i + 8] * a[i + 8]
|
||||
+ b[i + 9] * a[i + 9]
|
||||
+ b[i + 10] * a[i + 10]
|
||||
+ b[i + 11] * a[i + 11]
|
||||
+ b[i + 12] * a[i + 12]
|
||||
+ b[i + 13] * a[i + 13]
|
||||
+ b[i + 14] * a[i + 14]
|
||||
+ b[i + 15] * a[i + 15];
|
||||
res +=
|
||||
b[i + 16] * a[i + 16]
|
||||
+ b[i + 17] * a[i + 17]
|
||||
+ b[i + 18] * a[i + 18]
|
||||
+ b[i + 19] * a[i + 19]
|
||||
+ b[i + 20] * a[i + 20]
|
||||
+ b[i + 21] * a[i + 21]
|
||||
+ b[i + 22] * a[i + 22]
|
||||
+ b[i + 23] * a[i + 23];
|
||||
res +=
|
||||
b[i + 24] * a[i + 24]
|
||||
+ b[i + 25] * a[i + 25]
|
||||
+ b[i + 26] * a[i + 26]
|
||||
+ b[i + 27] * a[i + 27]
|
||||
+ b[i + 28] * a[i + 28]
|
||||
+ b[i + 29] * a[i + 29]
|
||||
+ b[i + 30] * a[i + 30]
|
||||
+ b[i + 31] * a[i + 31];
|
||||
}
|
||||
for (; i + 7 < a.length; i += 8) {
|
||||
res +=
|
||||
b[i + 0] * a[i + 0]
|
||||
+ b[i + 1] * a[i + 1]
|
||||
+ b[i + 2] * a[i + 2]
|
||||
+ b[i + 3] * a[i + 3]
|
||||
+ b[i + 4] * a[i + 4]
|
||||
+ b[i + 5] * a[i + 5]
|
||||
+ b[i + 6] * a[i + 6]
|
||||
+ b[i + 7] * a[i + 7];
|
||||
}
|
||||
return res;
|
||||
return PROVIDER.dotProduct(a, b);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -100,42 +42,19 @@ public final class VectorUtil {
|
|||
*
|
||||
* @throws IllegalArgumentException if the vectors' dimensions differ.
|
||||
*/
|
||||
public static float cosine(float[] v1, float[] v2) {
|
||||
if (v1.length != v2.length) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector dimensions differ: " + v1.length + "!=" + v2.length);
|
||||
public static float cosine(float[] a, float[] b) {
|
||||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
float norm1 = 0.0f;
|
||||
float norm2 = 0.0f;
|
||||
int dim = v1.length;
|
||||
|
||||
for (int i = 0; i < dim; i++) {
|
||||
float elem1 = v1[i];
|
||||
float elem2 = v2[i];
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt(norm1 * norm2));
|
||||
return PROVIDER.cosine(a, b);
|
||||
}
|
||||
|
||||
/** Returns the cosine similarity between the two vectors. */
|
||||
public static float cosine(byte[] a, byte[] b) {
|
||||
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
|
||||
int sum = 0;
|
||||
int norm1 = 0;
|
||||
int norm2 = 0;
|
||||
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
byte elem1 = a[i];
|
||||
byte elem2 = b[i];
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
return PROVIDER.cosine(a, b);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -143,52 +62,19 @@ public final class VectorUtil {
|
|||
*
|
||||
* @throws IllegalArgumentException if the vectors' dimensions differ.
|
||||
*/
|
||||
public static float squareDistance(float[] v1, float[] v2) {
|
||||
if (v1.length != v2.length) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector dimensions differ: " + v1.length + "!=" + v2.length);
|
||||
public static float squareDistance(float[] a, float[] b) {
|
||||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
float squareSum = 0.0f;
|
||||
int dim = v1.length;
|
||||
int i;
|
||||
for (i = 0; i + 8 <= dim; i += 8) {
|
||||
squareSum += squareDistanceUnrolled(v1, v2, i);
|
||||
}
|
||||
for (; i < dim; i++) {
|
||||
float diff = v1[i] - v2[i];
|
||||
squareSum += diff * diff;
|
||||
}
|
||||
return squareSum;
|
||||
}
|
||||
|
||||
private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
|
||||
float diff0 = v1[index + 0] - v2[index + 0];
|
||||
float diff1 = v1[index + 1] - v2[index + 1];
|
||||
float diff2 = v1[index + 2] - v2[index + 2];
|
||||
float diff3 = v1[index + 3] - v2[index + 3];
|
||||
float diff4 = v1[index + 4] - v2[index + 4];
|
||||
float diff5 = v1[index + 5] - v2[index + 5];
|
||||
float diff6 = v1[index + 6] - v2[index + 6];
|
||||
float diff7 = v1[index + 7] - v2[index + 7];
|
||||
return diff0 * diff0
|
||||
+ diff1 * diff1
|
||||
+ diff2 * diff2
|
||||
+ diff3 * diff3
|
||||
+ diff4 * diff4
|
||||
+ diff5 * diff5
|
||||
+ diff6 * diff6
|
||||
+ diff7 * diff7;
|
||||
return PROVIDER.squareDistance(a, b);
|
||||
}
|
||||
|
||||
/** Returns the sum of squared differences of the two vectors. */
|
||||
public static int squareDistance(byte[] a, byte[] b) {
|
||||
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
|
||||
int squareSum = 0;
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
int diff = a[i] - b[i];
|
||||
squareSum += diff * diff;
|
||||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
return squareSum;
|
||||
return PROVIDER.squareDistance(a, b);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -250,12 +136,10 @@ public final class VectorUtil {
|
|||
* @return the value of the dot product of the two vectors
|
||||
*/
|
||||
public static int dotProduct(byte[] a, byte[] b) {
|
||||
assert a.length == b.length;
|
||||
int total = 0;
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
total += a[i] * b[i];
|
||||
if (a.length != b.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
|
||||
}
|
||||
return total;
|
||||
return PROVIDER.dotProduct(a, b);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util;
|
||||
|
||||
/** The default VectorUtil provider implementation. */
|
||||
final class VectorUtilDefaultProvider implements VectorUtilProvider {
|
||||
|
||||
VectorUtilDefaultProvider() {}
|
||||
|
||||
@Override
|
||||
public float dotProduct(float[] a, float[] b) {
|
||||
float res = 0f;
|
||||
/*
|
||||
* If length of vector is larger than 8, we use unrolled dot product to accelerate the
|
||||
* calculation.
|
||||
*/
|
||||
int i;
|
||||
for (i = 0; i < a.length % 8; i++) {
|
||||
res += b[i] * a[i];
|
||||
}
|
||||
if (a.length < 8) {
|
||||
return res;
|
||||
}
|
||||
for (; i + 31 < a.length; i += 32) {
|
||||
res +=
|
||||
b[i + 0] * a[i + 0]
|
||||
+ b[i + 1] * a[i + 1]
|
||||
+ b[i + 2] * a[i + 2]
|
||||
+ b[i + 3] * a[i + 3]
|
||||
+ b[i + 4] * a[i + 4]
|
||||
+ b[i + 5] * a[i + 5]
|
||||
+ b[i + 6] * a[i + 6]
|
||||
+ b[i + 7] * a[i + 7];
|
||||
res +=
|
||||
b[i + 8] * a[i + 8]
|
||||
+ b[i + 9] * a[i + 9]
|
||||
+ b[i + 10] * a[i + 10]
|
||||
+ b[i + 11] * a[i + 11]
|
||||
+ b[i + 12] * a[i + 12]
|
||||
+ b[i + 13] * a[i + 13]
|
||||
+ b[i + 14] * a[i + 14]
|
||||
+ b[i + 15] * a[i + 15];
|
||||
res +=
|
||||
b[i + 16] * a[i + 16]
|
||||
+ b[i + 17] * a[i + 17]
|
||||
+ b[i + 18] * a[i + 18]
|
||||
+ b[i + 19] * a[i + 19]
|
||||
+ b[i + 20] * a[i + 20]
|
||||
+ b[i + 21] * a[i + 21]
|
||||
+ b[i + 22] * a[i + 22]
|
||||
+ b[i + 23] * a[i + 23];
|
||||
res +=
|
||||
b[i + 24] * a[i + 24]
|
||||
+ b[i + 25] * a[i + 25]
|
||||
+ b[i + 26] * a[i + 26]
|
||||
+ b[i + 27] * a[i + 27]
|
||||
+ b[i + 28] * a[i + 28]
|
||||
+ b[i + 29] * a[i + 29]
|
||||
+ b[i + 30] * a[i + 30]
|
||||
+ b[i + 31] * a[i + 31];
|
||||
}
|
||||
for (; i + 7 < a.length; i += 8) {
|
||||
res +=
|
||||
b[i + 0] * a[i + 0]
|
||||
+ b[i + 1] * a[i + 1]
|
||||
+ b[i + 2] * a[i + 2]
|
||||
+ b[i + 3] * a[i + 3]
|
||||
+ b[i + 4] * a[i + 4]
|
||||
+ b[i + 5] * a[i + 5]
|
||||
+ b[i + 6] * a[i + 6]
|
||||
+ b[i + 7] * a[i + 7];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float cosine(float[] a, float[] b) {
|
||||
float sum = 0.0f;
|
||||
float norm1 = 0.0f;
|
||||
float norm2 = 0.0f;
|
||||
int dim = a.length;
|
||||
|
||||
for (int i = 0; i < dim; i++) {
|
||||
float elem1 = a[i];
|
||||
float elem2 = b[i];
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt(norm1 * norm2));
|
||||
}
|
||||
|
||||
@Override
|
||||
public float squareDistance(float[] a, float[] b) {
|
||||
float squareSum = 0.0f;
|
||||
int dim = a.length;
|
||||
int i;
|
||||
for (i = 0; i + 8 <= dim; i += 8) {
|
||||
squareSum += squareDistanceUnrolled(a, b, i);
|
||||
}
|
||||
for (; i < dim; i++) {
|
||||
float diff = a[i] - b[i];
|
||||
squareSum += diff * diff;
|
||||
}
|
||||
return squareSum;
|
||||
}
|
||||
|
||||
private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
|
||||
float diff0 = v1[index + 0] - v2[index + 0];
|
||||
float diff1 = v1[index + 1] - v2[index + 1];
|
||||
float diff2 = v1[index + 2] - v2[index + 2];
|
||||
float diff3 = v1[index + 3] - v2[index + 3];
|
||||
float diff4 = v1[index + 4] - v2[index + 4];
|
||||
float diff5 = v1[index + 5] - v2[index + 5];
|
||||
float diff6 = v1[index + 6] - v2[index + 6];
|
||||
float diff7 = v1[index + 7] - v2[index + 7];
|
||||
return diff0 * diff0
|
||||
+ diff1 * diff1
|
||||
+ diff2 * diff2
|
||||
+ diff3 * diff3
|
||||
+ diff4 * diff4
|
||||
+ diff5 * diff5
|
||||
+ diff6 * diff6
|
||||
+ diff7 * diff7;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dotProduct(byte[] a, byte[] b) {
|
||||
int total = 0;
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
total += a[i] * b[i];
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float cosine(byte[] a, byte[] b) {
|
||||
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
|
||||
int sum = 0;
|
||||
int norm1 = 0;
|
||||
int norm2 = 0;
|
||||
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
byte elem1 = a[i];
|
||||
byte elem2 = b[i];
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int squareDistance(byte[] a, byte[] b) {
|
||||
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
|
||||
int squareSum = 0;
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
int diff = a[i] - b[i];
|
||||
squareSum += diff * diff;
|
||||
}
|
||||
return squareSum;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,145 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.lang.Runtime.Version;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.lang.invoke.MethodType;
|
||||
import java.security.AccessController;
|
||||
import java.security.PrivilegedAction;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/** A provider of VectorUtil implementations. */
|
||||
interface VectorUtilProvider {
|
||||
|
||||
/** Calculates the dot product of the given float arrays. */
|
||||
float dotProduct(float[] a, float[] b);
|
||||
|
||||
/** Returns the cosine similarity between the two vectors. */
|
||||
float cosine(float[] v1, float[] v2);
|
||||
|
||||
/** Returns the sum of squared differences of the two vectors. */
|
||||
float squareDistance(float[] a, float[] b);
|
||||
|
||||
/** Returns the dot product computed over signed bytes. */
|
||||
int dotProduct(byte[] a, byte[] b);
|
||||
|
||||
/** Returns the cosine similarity between the two byte vectors. */
|
||||
float cosine(byte[] a, byte[] b);
|
||||
|
||||
/** Returns the sum of squared differences of the two byte vectors. */
|
||||
int squareDistance(byte[] a, byte[] b);
|
||||
|
||||
// -- provider lookup mechanism
|
||||
|
||||
static final Logger LOG = Logger.getLogger(VectorUtilProvider.class.getName());
|
||||
|
||||
/** The minimal version of Java that has the bugfix for JDK-8301190. */
|
||||
static final Version VERSION_JDK8301190_FIXED = Version.parse("20.0.2");
|
||||
|
||||
static VectorUtilProvider lookup() {
|
||||
final int runtimeVersion = Runtime.version().feature();
|
||||
if (runtimeVersion == 20) {
|
||||
// is locale sane (only buggy in Java 20)
|
||||
if (isAffectedByJDK8301190()) {
|
||||
LOG.warning(
|
||||
"Java runtime is using a buggy default locale; Java vector incubator API can't be enabled: "
|
||||
+ Locale.getDefault());
|
||||
return new VectorUtilDefaultProvider();
|
||||
}
|
||||
// is the incubator module present and readable (JVM providers may to exclude them or it is
|
||||
// build with jlink)
|
||||
if (!vectorModulePresentAndReadable()) {
|
||||
LOG.warning(
|
||||
"Java vector incubator module is not readable. For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API.");
|
||||
return new VectorUtilDefaultProvider();
|
||||
}
|
||||
if (isClientVM()) {
|
||||
LOG.warning("C2 compiler is disabled; Java vector incubator API can't be enabled");
|
||||
return new VectorUtilDefaultProvider();
|
||||
}
|
||||
try {
|
||||
// we use method handles with lookup, so we do not need to deal with setAccessible as we
|
||||
// have private access through the lookup:
|
||||
final var lookup = MethodHandles.lookup();
|
||||
final var cls = lookup.findClass("org.apache.lucene.util.VectorUtilPanamaProvider");
|
||||
final var constr = lookup.findConstructor(cls, MethodType.methodType(void.class));
|
||||
try {
|
||||
return (VectorUtilProvider) constr.invoke();
|
||||
} catch (UnsupportedOperationException uoe) {
|
||||
// not supported because preferred vector size too small or similar
|
||||
LOG.warning("Java vector incubator API was not enabled. " + uoe.getMessage());
|
||||
return new VectorUtilDefaultProvider();
|
||||
} catch (RuntimeException | Error e) {
|
||||
throw e;
|
||||
} catch (Throwable th) {
|
||||
throw new AssertionError(th);
|
||||
}
|
||||
} catch (NoSuchMethodException | IllegalAccessException e) {
|
||||
throw new LinkageError(
|
||||
"VectorUtilPanamaProvider is missing correctly typed constructor", e);
|
||||
} catch (ClassNotFoundException cnfe) {
|
||||
throw new LinkageError("VectorUtilPanamaProvider is missing in Lucene JAR file", cnfe);
|
||||
}
|
||||
} else if (runtimeVersion >= 21) {
|
||||
LOG.warning(
|
||||
"You are running with Java 21 or later. To make full use of the Vector API, please update Apache Lucene.");
|
||||
}
|
||||
return new VectorUtilDefaultProvider();
|
||||
}
|
||||
|
||||
private static boolean vectorModulePresentAndReadable() {
|
||||
var opt =
|
||||
ModuleLayer.boot().modules().stream()
|
||||
.filter(m -> m.getName().equals("jdk.incubator.vector"))
|
||||
.findFirst();
|
||||
if (opt.isPresent()) {
|
||||
VectorUtilProvider.class.getModule().addReads(opt.get());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if runtime is affected by JDK-8301190 (avoids assertion when default language is say
|
||||
* "tr").
|
||||
*/
|
||||
private static boolean isAffectedByJDK8301190() {
|
||||
return VERSION_JDK8301190_FIXED.compareToIgnoreOptional(Runtime.version()) > 0
|
||||
&& !Objects.equals("I", "i".toUpperCase(Locale.getDefault()));
|
||||
}
|
||||
|
||||
@SuppressWarnings("removal")
|
||||
@SuppressForbidden(reason = "security manager")
|
||||
private static boolean isClientVM() {
|
||||
try {
|
||||
final PrivilegedAction<Boolean> action =
|
||||
() -> System.getProperty("java.vm.info", "").contains("emulated-client");
|
||||
return AccessController.doPrivileged(action);
|
||||
} catch (
|
||||
@SuppressWarnings("unused")
|
||||
SecurityException e) {
|
||||
LOG.warning(
|
||||
"SecurityManager denies permission to 'java.vm.info' system property, so state of C2 compiler can't be detected. "
|
||||
+ "In case of performance issues allow access to this property.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,477 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
import jdk.incubator.vector.ByteVector;
|
||||
import jdk.incubator.vector.FloatVector;
|
||||
import jdk.incubator.vector.IntVector;
|
||||
import jdk.incubator.vector.ShortVector;
|
||||
import jdk.incubator.vector.Vector;
|
||||
import jdk.incubator.vector.VectorOperators;
|
||||
import jdk.incubator.vector.VectorShape;
|
||||
import jdk.incubator.vector.VectorSpecies;
|
||||
|
||||
/** A VectorUtil provider implementation that leverages the Panama Vector API. */
|
||||
final class VectorUtilPanamaProvider implements VectorUtilProvider {
|
||||
|
||||
/**
|
||||
* The bit size of the preferred species (this field is package private to allow the lookup to
|
||||
* load it).
|
||||
*/
|
||||
static final int INT_SPECIES_PREF_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize();
|
||||
|
||||
private static final VectorSpecies<Float> PREF_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
|
||||
private static final VectorSpecies<Byte> PREF_BYTE_SPECIES;
|
||||
private static final VectorSpecies<Short> PREF_SHORT_SPECIES;
|
||||
|
||||
/**
|
||||
* x86 and less than 256-bit vectors.
|
||||
*
|
||||
* <p>it could be that it has only AVX1 and integer vectors are fast. it could also be that it has
|
||||
* no AVX and integer vectors are extremely slow. don't use integer vectors to avoid landmines.
|
||||
*/
|
||||
private static final boolean IS_AMD64_WITHOUT_AVX2 =
|
||||
Constants.OS_ARCH.equals("amd64") && INT_SPECIES_PREF_BIT_SIZE < 256;
|
||||
|
||||
static {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
PREF_BYTE_SPECIES =
|
||||
ByteVector.SPECIES_MAX.withShape(
|
||||
VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2));
|
||||
PREF_SHORT_SPECIES =
|
||||
ShortVector.SPECIES_MAX.withShape(
|
||||
VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1));
|
||||
} else {
|
||||
PREF_BYTE_SPECIES = null;
|
||||
PREF_SHORT_SPECIES = null;
|
||||
}
|
||||
}
|
||||
|
||||
VectorUtilPanamaProvider() {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE < 128) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Vector bit size is less than 128: " + INT_SPECIES_PREF_BIT_SIZE);
|
||||
}
|
||||
var log = Logger.getLogger(getClass().getName());
|
||||
log.info(
|
||||
"Java vector incubator API enabled; uses preferredBitSize=" + INT_SPECIES_PREF_BIT_SIZE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float dotProduct(float[] a, float[] b) {
|
||||
int i = 0;
|
||||
float res = 0;
|
||||
// if the array size is large (> 2x platform vector size), its worth the overhead to vectorize
|
||||
if (a.length > 2 * PREF_FLOAT_SPECIES.length()) {
|
||||
// vector loop is unrolled 4x (4 accumulators in parallel)
|
||||
FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length());
|
||||
for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
|
||||
acc1 = acc1.add(va.mul(vb));
|
||||
FloatVector vc =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vd =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length());
|
||||
acc2 = acc2.add(vc.mul(vd));
|
||||
FloatVector ve =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vf =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length());
|
||||
acc3 = acc3.add(ve.mul(vf));
|
||||
FloatVector vg =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vh =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length());
|
||||
acc4 = acc4.add(vg.mul(vh));
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
upperBound = PREF_FLOAT_SPECIES.loopBound(a.length);
|
||||
for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
|
||||
acc1 = acc1.add(va.mul(vb));
|
||||
}
|
||||
// reduce
|
||||
FloatVector res1 = acc1.add(acc2);
|
||||
FloatVector res2 = acc3.add(acc4);
|
||||
res += res1.add(res2).reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
|
||||
for (; i < a.length; i++) {
|
||||
res += b[i] * a[i];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float cosine(float[] a, float[] b) {
|
||||
int i = 0;
|
||||
float sum = 0;
|
||||
float norm1 = 0;
|
||||
float norm2 = 0;
|
||||
// if the array size is large (> 2x platform vector size), its worth the overhead to vectorize
|
||||
if (a.length > 2 * PREF_FLOAT_SPECIES.length()) {
|
||||
// vector loop is unrolled 4x (4 accumulators in parallel)
|
||||
FloatVector sum1 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector sum2 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector sum3 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector sum4 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm1_1 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm1_2 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm1_3 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm1_4 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm2_1 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm2_2 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm2_3 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector norm2_4 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length());
|
||||
for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
FloatVector vc =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vd =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length());
|
||||
sum2 = sum2.add(vc.mul(vd));
|
||||
norm1_2 = norm1_2.add(vc.mul(vc));
|
||||
norm2_2 = norm2_2.add(vd.mul(vd));
|
||||
FloatVector ve =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vf =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length());
|
||||
sum3 = sum3.add(ve.mul(vf));
|
||||
norm1_3 = norm1_3.add(ve.mul(ve));
|
||||
norm2_3 = norm2_3.add(vf.mul(vf));
|
||||
FloatVector vg =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vh =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length());
|
||||
sum4 = sum4.add(vg.mul(vh));
|
||||
norm1_4 = norm1_4.add(vg.mul(vg));
|
||||
norm2_4 = norm2_4.add(vh.mul(vh));
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
upperBound = PREF_FLOAT_SPECIES.loopBound(a.length);
|
||||
for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
}
|
||||
// reduce
|
||||
FloatVector sumres1 = sum1.add(sum2);
|
||||
FloatVector sumres2 = sum3.add(sum4);
|
||||
FloatVector norm1res1 = norm1_1.add(norm1_2);
|
||||
FloatVector norm1res2 = norm1_3.add(norm1_4);
|
||||
FloatVector norm2res1 = norm2_1.add(norm2_2);
|
||||
FloatVector norm2res2 = norm2_3.add(norm2_4);
|
||||
sum += sumres1.add(sumres2).reduceLanes(VectorOperators.ADD);
|
||||
norm1 += norm1res1.add(norm1res2).reduceLanes(VectorOperators.ADD);
|
||||
norm2 += norm2res1.add(norm2res2).reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
|
||||
for (; i < a.length; i++) {
|
||||
float elem1 = a[i];
|
||||
float elem2 = b[i];
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt(norm1 * norm2));
|
||||
}
|
||||
|
||||
@Override
|
||||
public float squareDistance(float[] a, float[] b) {
|
||||
int i = 0;
|
||||
float res = 0;
|
||||
// if the array size is large (> 2x platform vector size), its worth the overhead to vectorize
|
||||
if (a.length > 2 * PREF_FLOAT_SPECIES.length()) {
|
||||
// vector loop is unrolled 4x (4 accumulators in parallel)
|
||||
FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES);
|
||||
int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length());
|
||||
for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
|
||||
FloatVector diff1 = va.sub(vb);
|
||||
acc1 = acc1.add(diff1.mul(diff1));
|
||||
FloatVector vc =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vd =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length());
|
||||
FloatVector diff2 = vc.sub(vd);
|
||||
acc2 = acc2.add(diff2.mul(diff2));
|
||||
FloatVector ve =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vf =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector diff3 = ve.sub(vf);
|
||||
acc3 = acc3.add(diff3.mul(diff3));
|
||||
FloatVector vg =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector vh =
|
||||
FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length());
|
||||
FloatVector diff4 = vg.sub(vh);
|
||||
acc4 = acc4.add(diff4.mul(diff4));
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
upperBound = PREF_FLOAT_SPECIES.loopBound(a.length);
|
||||
for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
|
||||
FloatVector diff = va.sub(vb);
|
||||
acc1 = acc1.add(diff.mul(diff));
|
||||
}
|
||||
// reduce
|
||||
FloatVector res1 = acc1.add(acc2);
|
||||
FloatVector res2 = acc3.add(acc4);
|
||||
res += res1.add(res2).reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
|
||||
for (; i < a.length; i++) {
|
||||
float diff = a[i] - b[i];
|
||||
res += diff * diff;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// Binary functions, these all follow a general pattern like this:
|
||||
//
|
||||
// short intermediate = a * b;
|
||||
// int accumulator = accumulator + intermediate;
|
||||
//
|
||||
// 256 or 512 bit vectors can process 64 or 128 bits at a time, respectively
|
||||
// intermediate results use 128 or 256 bit vectors, respectively
|
||||
// final accumulator uses 256 or 512 bit vectors, respectively
|
||||
//
|
||||
// We also support 128 bit vectors, using two 128 bit accumulators.
|
||||
// This is slower but still faster than not vectorizing at all.
|
||||
|
||||
@Override
|
||||
public int dotProduct(byte[] a, byte[] b) {
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && IS_AMD64_WITHOUT_AVX2 == false) {
|
||||
// compute vectorized dot product consistent with VPDPBUSD instruction
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
|
||||
int upperBound = PREF_BYTE_SPECIES.loopBound(a.length);
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i);
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
Vector<Integer> prod32 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
acc = acc.add(prod32);
|
||||
}
|
||||
// reduce
|
||||
res += acc.reduceLanes(VectorOperators.ADD);
|
||||
} else {
|
||||
// 128-bit implementation, which must "split up" vectors due to widening conversions
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
|
||||
IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
// expand each byte vector into short vector and multiply
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
// split each short vector into two int vectors and add
|
||||
Vector<Integer> prod32_1 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> prod32_2 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
acc1 = acc1.add(prod32_1);
|
||||
acc2 = acc2.add(prod32_2);
|
||||
}
|
||||
// reduce
|
||||
res += acc1.add(acc2).reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
}
|
||||
|
||||
for (; i < a.length; i++) {
|
||||
res += b[i] * a[i];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float cosine(byte[] a, byte[] b) {
|
||||
int i = 0;
|
||||
int sum = 0;
|
||||
int norm1 = 0;
|
||||
int norm2 = 0;
|
||||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && IS_AMD64_WITHOUT_AVX2 == false) {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
|
||||
int upperBound = PREF_BYTE_SPECIES.loopBound(a.length);
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i);
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
Vector<Integer> prod32 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm1_32 =
|
||||
norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm2_32 =
|
||||
norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
accSum = accSum.add(prod32);
|
||||
accNorm1 = accNorm1.add(norm1_32);
|
||||
accNorm2 = accNorm2.add(norm2_32);
|
||||
}
|
||||
// reduce
|
||||
sum += accSum.reduceLanes(VectorOperators.ADD);
|
||||
norm1 += accNorm1.reduceLanes(VectorOperators.ADD);
|
||||
norm2 += accNorm2.reduceLanes(VectorOperators.ADD);
|
||||
} else {
|
||||
// 128-bit implementation, which must "split up" vectors due to widening conversions
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
|
||||
IntVector accSum1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accSum2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1_1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1_2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2_1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2_2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
// expand each byte vector into short vector and perform multiplications
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
// split each short vector into two int vectors and add
|
||||
Vector<Integer> prod32_1 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> prod32_2 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
Vector<Integer> norm1_32_1 =
|
||||
norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> norm1_32_2 =
|
||||
norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
Vector<Integer> norm2_32_1 =
|
||||
norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> norm2_32_2 =
|
||||
norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
accSum1 = accSum1.add(prod32_1);
|
||||
accSum2 = accSum2.add(prod32_2);
|
||||
accNorm1_1 = accNorm1_1.add(norm1_32_1);
|
||||
accNorm1_2 = accNorm1_2.add(norm1_32_2);
|
||||
accNorm2_1 = accNorm2_1.add(norm2_32_1);
|
||||
accNorm2_2 = accNorm2_2.add(norm2_32_2);
|
||||
}
|
||||
// reduce
|
||||
sum += accSum1.add(accSum2).reduceLanes(VectorOperators.ADD);
|
||||
norm1 += accNorm1_1.add(accNorm1_2).reduceLanes(VectorOperators.ADD);
|
||||
norm2 += accNorm2_1.add(accNorm2_2).reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
}
|
||||
|
||||
for (; i < a.length; i++) {
|
||||
byte elem1 = a[i];
|
||||
byte elem2 = b[i];
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int squareDistance(byte[] a, byte[] b) {
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && IS_AMD64_WITHOUT_AVX2 == false) {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
|
||||
int upperBound = PREF_BYTE_SPECIES.loopBound(a.length);
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i);
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
|
||||
Vector<Short> diff16 = va16.sub(vb16);
|
||||
Vector<Integer> diff32 =
|
||||
diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
acc = acc.add(diff32.mul(diff32));
|
||||
}
|
||||
// reduce
|
||||
res += acc.reduceLanes(VectorOperators.ADD);
|
||||
} else {
|
||||
// 128-bit implementation, which must "split up" vectors due to widening conversions
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
|
||||
IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
// expand each byte vector into short vector and subtract
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> diff16 = va16.sub(vb16);
|
||||
// split each short vector into two int vectors, square, and add
|
||||
Vector<Integer> diff32_1 =
|
||||
diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> diff32_2 =
|
||||
diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
acc1 = acc1.add(diff32_1.mul(diff32_1));
|
||||
acc2 = acc2.add(diff32_2.mul(diff32_2));
|
||||
}
|
||||
// reduce
|
||||
res += acc1.add(acc2).reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
}
|
||||
|
||||
for (; i < a.length; i++) {
|
||||
int diff = a[i] - b[i];
|
||||
res += diff * diff;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
import java.util.function.ToDoubleFunction;
|
||||
import java.util.function.ToIntFunction;
|
||||
import java.util.stream.IntStream;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
public class TestVectorUtilProviders extends LuceneTestCase {
|
||||
|
||||
private static final double DELTA = 1e-4;
|
||||
private static final VectorUtilProvider LUCENE_PROVIDER = new VectorUtilDefaultProvider();
|
||||
private static final VectorUtilProvider JDK_PROVIDER = VectorUtil.PROVIDER;
|
||||
|
||||
private static final int[] VECTOR_SIZES = {
|
||||
1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024
|
||||
};
|
||||
|
||||
private final int size;
|
||||
|
||||
public TestVectorUtilProviders(int size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
@ParametersFactory
|
||||
public static Iterable<Object[]> parametersFactory() {
|
||||
return () -> IntStream.of(VECTOR_SIZES).boxed().map(i -> new Object[] {i}).iterator();
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
assumeFalse(
|
||||
"Test only works when JDK's vector incubator module is enabled.",
|
||||
JDK_PROVIDER instanceof VectorUtilDefaultProvider);
|
||||
}
|
||||
|
||||
public void testFloatVectors() {
|
||||
var a = new float[size];
|
||||
var b = new float[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
a[i] = random().nextFloat();
|
||||
b[i] = random().nextFloat();
|
||||
}
|
||||
assertFloatReturningProviders(p -> p.dotProduct(a, b));
|
||||
assertFloatReturningProviders(p -> p.squareDistance(a, b));
|
||||
assertFloatReturningProviders(p -> p.cosine(a, b));
|
||||
}
|
||||
|
||||
public void testBinaryVectors() {
|
||||
var a = new byte[size];
|
||||
var b = new byte[size];
|
||||
random().nextBytes(a);
|
||||
random().nextBytes(b);
|
||||
assertIntReturningProviders(p -> p.dotProduct(a, b));
|
||||
assertIntReturningProviders(p -> p.squareDistance(a, b));
|
||||
assertFloatReturningProviders(p -> p.cosine(a, b));
|
||||
}
|
||||
|
||||
private void assertFloatReturningProviders(ToDoubleFunction<VectorUtilProvider> func) {
|
||||
assertEquals(func.applyAsDouble(LUCENE_PROVIDER), func.applyAsDouble(JDK_PROVIDER), DELTA);
|
||||
}
|
||||
|
||||
private void assertIntReturningProviders(ToIntFunction<VectorUtilProvider> func) {
|
||||
assertEquals(func.applyAsInt(LUCENE_PROVIDER), func.applyAsInt(JDK_PROVIDER));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue