diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f49685ffd25..49a4da82b53 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -352,6 +352,9 @@ Optimizations * GITHUB#13581: OnHeapHnswGraph no longer allocates a lock for every graph node (Mike Sokolov) +* GITHUB#13636: Optimizations to the decoding logic of blocks of postings. + (Adrien Grand, Uwe Schindler) + Changes in runtime behavior --------------------- diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/PostingIndexInputBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/PostingIndexInputBenchmark.java new file mode 100644 index 00000000000..3804af0b167 --- /dev/null +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/PostingIndexInputBenchmark.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.benchmark.jmh; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.lucene.codecs.lucene912.ForUtil; +import org.apache.lucene.codecs.lucene912.PostingIndexInput; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.IOUtils; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Benchmark) +@Warmup(iterations = 5, time = 1) +@Measurement(iterations = 5, time = 1) +@Fork( + value = 3, + jvmArgsAppend = {"-Xmx1g", "-Xms1g", "-XX:+AlwaysPreTouch"}) +public class PostingIndexInputBenchmark { + + private Path path; + private Directory dir; + private IndexInput in; + private PostingIndexInput postingIn; + private final ForUtil forUtil = new ForUtil(); + private final long[] values = new long[128]; + + @Param({"5", "6", "7", "8", "9", "10"}) + public int bpv; + + @Setup(Level.Trial) + public void setup() throws Exception { + path = Files.createTempDirectory("forUtil"); + dir = MMapDirectory.open(path); + try (IndexOutput out = dir.createOutput("docs", IOContext.DEFAULT)) { + Random r = new Random(0); + // Write enough random data to not reach EOF while decoding + for (int i = 0; i < 100; ++i) { + out.writeLong(r.nextLong()); + } + } + in = dir.openInput("docs", IOContext.DEFAULT); + postingIn = new PostingIndexInput(in, forUtil); + } + + @TearDown(Level.Trial) + public void tearDown() throws Exception { + if (dir != null) { + dir.deleteFile("docs"); + } + IOUtils.close(in, dir); + in = null; + dir = null; + Files.deleteIfExists(path); + } + + @Benchmark + public void decode(Blackhole bh) throws IOException { + in.seek(3); // random unaligned offset + postingIn.decode(bpv, values); + bh.consume(values); + } + + @Benchmark + public void decodeAndPrefixSum(Blackhole bh) throws IOException { + in.seek(3); // random unaligned offset + postingIn.decodeAndPrefixSum(bpv, 100, values); + bh.consume(values); + } +} diff --git a/lucene/core/src/generated/checksums/generateForUtil.json b/lucene/core/src/generated/checksums/generateForUtil.json index 752285f4d7f..e147f2c62f7 100644 --- a/lucene/core/src/generated/checksums/generateForUtil.json +++ b/lucene/core/src/generated/checksums/generateForUtil.json @@ -1,4 +1,4 @@ { - "lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java": "5ff856e80cab30f9e5704aa89f3197f017d07624", - "lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py": "3ccf92b3ddbff6340a13e8a55090bfb900dc7be2" + "lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java": "cc7d40997e2d6500b79c19ff47461ed6e89d2268", + "lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py": "ba029f2e374e66c6cf315b2c93f4efa6944dfbb8" } \ No newline at end of file diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForDeltaUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForDeltaUtil.java index 8b9aedcfb2b..b53c18fa3f8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForDeltaUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForDeltaUtil.java @@ -22,7 +22,7 @@ import org.apache.lucene.store.DataOutput; import org.apache.lucene.util.packed.PackedInts; /** Utility class to encode/decode increasing sequences of 128 integers. */ -public class ForDeltaUtil { +final class ForDeltaUtil { // IDENTITY_PLUS_ONE[i] == i+1 private static final long[] IDENTITY_PLUS_ONE = new long[ForUtil.BLOCK_SIZE]; @@ -67,12 +67,12 @@ public class ForDeltaUtil { } /** Decode deltas, compute the prefix sum and add {@code base} to all decoded longs. */ - void decodeAndPrefixSum(DataInput in, long base, long[] longs) throws IOException { - final int bitsPerValue = Byte.toUnsignedInt(in.readByte()); + void decodeAndPrefixSum(PostingIndexInput in, long base, long[] longs) throws IOException { + final int bitsPerValue = Byte.toUnsignedInt(in.in.readByte()); if (bitsPerValue == 0) { prefixSumOfOnes(longs, base); } else { - forUtil.decodeAndPrefixSum(bitsPerValue, in, base, longs); + in.decodeAndPrefixSum(bitsPerValue, base, longs); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java index 63ee7baaf10..4408a8a3e57 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java @@ -19,17 +19,18 @@ package org.apache.lucene.codecs.lucene912; import java.io.IOException; -import org.apache.lucene.store.DataInput; +import org.apache.lucene.internal.vectorization.PostingDecodingUtil; import org.apache.lucene.store.DataOutput; +import org.apache.lucene.store.IndexInput; -// Inspired from https://fulmicoton.com/posts/bitpacking/ -// Encodes multiple integers in a long to get SIMD-like speedups. -// If bitsPerValue <= 8 then we pack 8 ints per long -// else if bitsPerValue <= 16 we pack 4 ints per long -// else we pack 2 ints per long -final class ForUtil { +/** + * Inspired from https://fulmicoton.com/posts/bitpacking/ Encodes multiple integers in a long to get + * SIMD-like speedups. If bitsPerValue <= 8 then we pack 8 ints per long else if bitsPerValue + * <= 16 we pack 4 ints per long else we pack 2 ints per long + */ +public final class ForUtil { - static final int BLOCK_SIZE = 128; + public static final int BLOCK_SIZE = 128; private static final int BLOCK_SIZE_LOG2 = 7; private static long expandMask32(long mask32) { @@ -300,13 +301,14 @@ final class ForUtil { return bitsPerValue << (BLOCK_SIZE_LOG2 - 3); } - private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) + private static void decodeSlow( + int bitsPerValue, IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException { final int numLongs = bitsPerValue << 1; - in.readLongs(tmp, 0, numLongs); final long mask = MASKS32[bitsPerValue]; - int longsIdx = 0; - int shift = 32 - bitsPerValue; + pdu.splitLongs(numLongs, longs, 32 - bitsPerValue, mask, tmp, 0, -1L); + int longsIdx = numLongs; + int shift = 32 - 2 * bitsPerValue; for (; shift >= 0; shift -= bitsPerValue) { shiftLongs(tmp, numLongs, longs, longsIdx, shift, mask); longsIdx += numLongs; @@ -342,6 +344,13 @@ final class ForUtil { } } + /** Likewise, but for a simple mask. */ + private static void maskLongs(long[] a, int count, long[] b, int bi, long mask) { + for (int i = 0; i < count; ++i) { + b[bi + i] = a[i] & mask; + } + } + private static final long[] MASKS8 = new long[8]; private static final long[] MASKS16 = new long[16]; private static final long[] MASKS32 = new long[32]; @@ -406,279 +415,280 @@ final class ForUtil { private static final long MASK32_24 = MASKS32[24]; /** Decode 128 integers into {@code longs}. */ - void decode(int bitsPerValue, DataInput in, long[] longs) throws IOException { + void decode(int bitsPerValue, IndexInput in, PostingDecodingUtil pdu, long[] longs) + throws IOException { switch (bitsPerValue) { case 1: - decode1(in, tmp, longs); + decode1(in, pdu, tmp, longs); expand8(longs); break; case 2: - decode2(in, tmp, longs); + decode2(in, pdu, tmp, longs); expand8(longs); break; case 3: - decode3(in, tmp, longs); + decode3(in, pdu, tmp, longs); expand8(longs); break; case 4: - decode4(in, tmp, longs); + decode4(in, pdu, tmp, longs); expand8(longs); break; case 5: - decode5(in, tmp, longs); + decode5(in, pdu, tmp, longs); expand8(longs); break; case 6: - decode6(in, tmp, longs); + decode6(in, pdu, tmp, longs); expand8(longs); break; case 7: - decode7(in, tmp, longs); + decode7(in, pdu, tmp, longs); expand8(longs); break; case 8: - decode8(in, tmp, longs); + decode8(in, pdu, tmp, longs); expand8(longs); break; case 9: - decode9(in, tmp, longs); + decode9(in, pdu, tmp, longs); expand16(longs); break; case 10: - decode10(in, tmp, longs); + decode10(in, pdu, tmp, longs); expand16(longs); break; case 11: - decode11(in, tmp, longs); + decode11(in, pdu, tmp, longs); expand16(longs); break; case 12: - decode12(in, tmp, longs); + decode12(in, pdu, tmp, longs); expand16(longs); break; case 13: - decode13(in, tmp, longs); + decode13(in, pdu, tmp, longs); expand16(longs); break; case 14: - decode14(in, tmp, longs); + decode14(in, pdu, tmp, longs); expand16(longs); break; case 15: - decode15(in, tmp, longs); + decode15(in, pdu, tmp, longs); expand16(longs); break; case 16: - decode16(in, tmp, longs); + decode16(in, pdu, tmp, longs); expand16(longs); break; case 17: - decode17(in, tmp, longs); + decode17(in, pdu, tmp, longs); expand32(longs); break; case 18: - decode18(in, tmp, longs); + decode18(in, pdu, tmp, longs); expand32(longs); break; case 19: - decode19(in, tmp, longs); + decode19(in, pdu, tmp, longs); expand32(longs); break; case 20: - decode20(in, tmp, longs); + decode20(in, pdu, tmp, longs); expand32(longs); break; case 21: - decode21(in, tmp, longs); + decode21(in, pdu, tmp, longs); expand32(longs); break; case 22: - decode22(in, tmp, longs); + decode22(in, pdu, tmp, longs); expand32(longs); break; case 23: - decode23(in, tmp, longs); + decode23(in, pdu, tmp, longs); expand32(longs); break; case 24: - decode24(in, tmp, longs); + decode24(in, pdu, tmp, longs); expand32(longs); break; default: - decodeSlow(bitsPerValue, in, tmp, longs); + decodeSlow(bitsPerValue, in, pdu, tmp, longs); expand32(longs); break; } } /** Delta-decode 128 integers into {@code longs}. */ - void decodeAndPrefixSum(int bitsPerValue, DataInput in, long base, long[] longs) + void decodeAndPrefixSum( + int bitsPerValue, IndexInput in, PostingDecodingUtil pdu, long base, long[] longs) throws IOException { switch (bitsPerValue) { case 1: - decode1(in, tmp, longs); + decode1(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 2: - decode2(in, tmp, longs); + decode2(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 3: - decode3(in, tmp, longs); + decode3(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 4: - decode4(in, tmp, longs); + decode4(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 5: - decode5(in, tmp, longs); + decode5(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 6: - decode6(in, tmp, longs); + decode6(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 7: - decode7(in, tmp, longs); + decode7(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 8: - decode8(in, tmp, longs); + decode8(in, pdu, tmp, longs); prefixSum8(longs, base); break; case 9: - decode9(in, tmp, longs); + decode9(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 10: - decode10(in, tmp, longs); + decode10(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 11: - decode11(in, tmp, longs); + decode11(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 12: - decode12(in, tmp, longs); + decode12(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 13: - decode13(in, tmp, longs); + decode13(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 14: - decode14(in, tmp, longs); + decode14(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 15: - decode15(in, tmp, longs); + decode15(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 16: - decode16(in, tmp, longs); + decode16(in, pdu, tmp, longs); prefixSum16(longs, base); break; case 17: - decode17(in, tmp, longs); + decode17(in, pdu, tmp, longs); prefixSum32(longs, base); break; case 18: - decode18(in, tmp, longs); + decode18(in, pdu, tmp, longs); prefixSum32(longs, base); break; case 19: - decode19(in, tmp, longs); + decode19(in, pdu, tmp, longs); prefixSum32(longs, base); break; case 20: - decode20(in, tmp, longs); + decode20(in, pdu, tmp, longs); prefixSum32(longs, base); break; case 21: - decode21(in, tmp, longs); + decode21(in, pdu, tmp, longs); prefixSum32(longs, base); break; case 22: - decode22(in, tmp, longs); + decode22(in, pdu, tmp, longs); prefixSum32(longs, base); break; case 23: - decode23(in, tmp, longs); + decode23(in, pdu, tmp, longs); prefixSum32(longs, base); break; case 24: - decode24(in, tmp, longs); + decode24(in, pdu, tmp, longs); prefixSum32(longs, base); break; default: - decodeSlow(bitsPerValue, in, tmp, longs); + decodeSlow(bitsPerValue, in, pdu, tmp, longs); prefixSum32(longs, base); break; } } - private static void decode1(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 2); - shiftLongs(tmp, 2, longs, 0, 7, MASK8_1); + private static void decode1(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(2, longs, 7, MASK8_1, tmp, 0, MASK8_7); shiftLongs(tmp, 2, longs, 2, 6, MASK8_1); shiftLongs(tmp, 2, longs, 4, 5, MASK8_1); shiftLongs(tmp, 2, longs, 6, 4, MASK8_1); shiftLongs(tmp, 2, longs, 8, 3, MASK8_1); shiftLongs(tmp, 2, longs, 10, 2, MASK8_1); shiftLongs(tmp, 2, longs, 12, 1, MASK8_1); - shiftLongs(tmp, 2, longs, 14, 0, MASK8_1); + maskLongs(tmp, 2, longs, 14, MASK8_1); } - private static void decode2(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 4); - shiftLongs(tmp, 4, longs, 0, 6, MASK8_2); + private static void decode2(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(4, longs, 6, MASK8_2, tmp, 0, MASK8_6); shiftLongs(tmp, 4, longs, 4, 4, MASK8_2); shiftLongs(tmp, 4, longs, 8, 2, MASK8_2); - shiftLongs(tmp, 4, longs, 12, 0, MASK8_2); + maskLongs(tmp, 4, longs, 12, MASK8_2); } - private static void decode3(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 6); - shiftLongs(tmp, 6, longs, 0, 5, MASK8_3); + private static void decode3(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(6, longs, 5, MASK8_3, tmp, 0, MASK8_5); shiftLongs(tmp, 6, longs, 6, 2, MASK8_3); + maskLongs(tmp, 6, tmp, 0, MASK8_2); for (int iter = 0, tmpIdx = 0, longsIdx = 12; iter < 2; ++iter, tmpIdx += 3, longsIdx += 2) { - long l0 = (tmp[tmpIdx + 0] & MASK8_2) << 1; + long l0 = tmp[tmpIdx + 0] << 1; l0 |= (tmp[tmpIdx + 1] >>> 1) & MASK8_1; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK8_1) << 2; - l1 |= (tmp[tmpIdx + 2] & MASK8_2) << 0; + l1 |= tmp[tmpIdx + 2] << 0; longs[longsIdx + 1] = l1; } } - private static void decode4(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 8); - shiftLongs(tmp, 8, longs, 0, 4, MASK8_4); - shiftLongs(tmp, 8, longs, 8, 0, MASK8_4); + private static void decode4(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(8, longs, 4, MASK8_4, longs, 8, MASK8_4); } - private static void decode5(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 10); - shiftLongs(tmp, 10, longs, 0, 3, MASK8_5); + private static void decode5(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(10, longs, 3, MASK8_5, tmp, 0, MASK8_3); for (int iter = 0, tmpIdx = 0, longsIdx = 10; iter < 2; ++iter, tmpIdx += 5, longsIdx += 3) { - long l0 = (tmp[tmpIdx + 0] & MASK8_3) << 2; + long l0 = tmp[tmpIdx + 0] << 2; l0 |= (tmp[tmpIdx + 1] >>> 1) & MASK8_2; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK8_1) << 4; - l1 |= (tmp[tmpIdx + 2] & MASK8_3) << 1; + l1 |= tmp[tmpIdx + 2] << 1; l1 |= (tmp[tmpIdx + 3] >>> 2) & MASK8_1; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 3] & MASK8_2) << 3; - l2 |= (tmp[tmpIdx + 4] & MASK8_3) << 0; + l2 |= tmp[tmpIdx + 4] << 0; longs[longsIdx + 2] = l2; } } - private static void decode6(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 12); - shiftLongs(tmp, 12, longs, 0, 2, MASK8_6); - shiftLongs(tmp, 12, tmp, 0, 0, MASK8_2); + private static void decode6(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(12, longs, 2, MASK8_6, tmp, 0, MASK8_2); for (int iter = 0, tmpIdx = 0, longsIdx = 12; iter < 4; ++iter, tmpIdx += 3, longsIdx += 1) { long l0 = tmp[tmpIdx + 0] << 4; l0 |= tmp[tmpIdx + 1] << 2; @@ -687,10 +697,9 @@ final class ForUtil { } } - private static void decode7(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 14); - shiftLongs(tmp, 14, longs, 0, 1, MASK8_7); - shiftLongs(tmp, 14, tmp, 0, 0, MASK8_1); + private static void decode7(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(14, longs, 1, MASK8_7, tmp, 0, MASK8_1); for (int iter = 0, tmpIdx = 0, longsIdx = 14; iter < 2; ++iter, tmpIdx += 7, longsIdx += 1) { long l0 = tmp[tmpIdx + 0] << 6; l0 |= tmp[tmpIdx + 1] << 5; @@ -703,15 +712,16 @@ final class ForUtil { } } - private static void decode8(DataInput in, long[] tmp, long[] longs) throws IOException { + private static void decode8(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { in.readLongs(longs, 0, 16); } - private static void decode9(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 18); - shiftLongs(tmp, 18, longs, 0, 7, MASK16_9); + private static void decode9(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(18, longs, 7, MASK16_9, tmp, 0, MASK16_7); for (int iter = 0, tmpIdx = 0, longsIdx = 18; iter < 2; ++iter, tmpIdx += 9, longsIdx += 7) { - long l0 = (tmp[tmpIdx + 0] & MASK16_7) << 2; + long l0 = tmp[tmpIdx + 0] << 2; l0 |= (tmp[tmpIdx + 1] >>> 5) & MASK16_2; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK16_5) << 4; @@ -721,7 +731,7 @@ final class ForUtil { l2 |= (tmp[tmpIdx + 3] >>> 1) & MASK16_6; longs[longsIdx + 2] = l2; long l3 = (tmp[tmpIdx + 3] & MASK16_1) << 8; - l3 |= (tmp[tmpIdx + 4] & MASK16_7) << 1; + l3 |= tmp[tmpIdx + 4] << 1; l3 |= (tmp[tmpIdx + 5] >>> 6) & MASK16_1; longs[longsIdx + 3] = l3; long l4 = (tmp[tmpIdx + 5] & MASK16_6) << 3; @@ -731,59 +741,58 @@ final class ForUtil { l5 |= (tmp[tmpIdx + 7] >>> 2) & MASK16_5; longs[longsIdx + 5] = l5; long l6 = (tmp[tmpIdx + 7] & MASK16_2) << 7; - l6 |= (tmp[tmpIdx + 8] & MASK16_7) << 0; + l6 |= tmp[tmpIdx + 8] << 0; longs[longsIdx + 6] = l6; } } - private static void decode10(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 20); - shiftLongs(tmp, 20, longs, 0, 6, MASK16_10); + private static void decode10(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(20, longs, 6, MASK16_10, tmp, 0, MASK16_6); for (int iter = 0, tmpIdx = 0, longsIdx = 20; iter < 4; ++iter, tmpIdx += 5, longsIdx += 3) { - long l0 = (tmp[tmpIdx + 0] & MASK16_6) << 4; + long l0 = tmp[tmpIdx + 0] << 4; l0 |= (tmp[tmpIdx + 1] >>> 2) & MASK16_4; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK16_2) << 8; - l1 |= (tmp[tmpIdx + 2] & MASK16_6) << 2; + l1 |= tmp[tmpIdx + 2] << 2; l1 |= (tmp[tmpIdx + 3] >>> 4) & MASK16_2; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 3] & MASK16_4) << 6; - l2 |= (tmp[tmpIdx + 4] & MASK16_6) << 0; + l2 |= tmp[tmpIdx + 4] << 0; longs[longsIdx + 2] = l2; } } - private static void decode11(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 22); - shiftLongs(tmp, 22, longs, 0, 5, MASK16_11); + private static void decode11(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(22, longs, 5, MASK16_11, tmp, 0, MASK16_5); for (int iter = 0, tmpIdx = 0, longsIdx = 22; iter < 2; ++iter, tmpIdx += 11, longsIdx += 5) { - long l0 = (tmp[tmpIdx + 0] & MASK16_5) << 6; - l0 |= (tmp[tmpIdx + 1] & MASK16_5) << 1; + long l0 = tmp[tmpIdx + 0] << 6; + l0 |= tmp[tmpIdx + 1] << 1; l0 |= (tmp[tmpIdx + 2] >>> 4) & MASK16_1; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 2] & MASK16_4) << 7; - l1 |= (tmp[tmpIdx + 3] & MASK16_5) << 2; + l1 |= tmp[tmpIdx + 3] << 2; l1 |= (tmp[tmpIdx + 4] >>> 3) & MASK16_2; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 4] & MASK16_3) << 8; - l2 |= (tmp[tmpIdx + 5] & MASK16_5) << 3; + l2 |= tmp[tmpIdx + 5] << 3; l2 |= (tmp[tmpIdx + 6] >>> 2) & MASK16_3; longs[longsIdx + 2] = l2; long l3 = (tmp[tmpIdx + 6] & MASK16_2) << 9; - l3 |= (tmp[tmpIdx + 7] & MASK16_5) << 4; + l3 |= tmp[tmpIdx + 7] << 4; l3 |= (tmp[tmpIdx + 8] >>> 1) & MASK16_4; longs[longsIdx + 3] = l3; long l4 = (tmp[tmpIdx + 8] & MASK16_1) << 10; - l4 |= (tmp[tmpIdx + 9] & MASK16_5) << 5; - l4 |= (tmp[tmpIdx + 10] & MASK16_5) << 0; + l4 |= tmp[tmpIdx + 9] << 5; + l4 |= tmp[tmpIdx + 10] << 0; longs[longsIdx + 4] = l4; } } - private static void decode12(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 24); - shiftLongs(tmp, 24, longs, 0, 4, MASK16_12); - shiftLongs(tmp, 24, tmp, 0, 0, MASK16_4); + private static void decode12(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(24, longs, 4, MASK16_12, tmp, 0, MASK16_4); for (int iter = 0, tmpIdx = 0, longsIdx = 24; iter < 8; ++iter, tmpIdx += 3, longsIdx += 1) { long l0 = tmp[tmpIdx + 0] << 8; l0 |= tmp[tmpIdx + 1] << 4; @@ -792,35 +801,34 @@ final class ForUtil { } } - private static void decode13(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 26); - shiftLongs(tmp, 26, longs, 0, 3, MASK16_13); + private static void decode13(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(26, longs, 3, MASK16_13, tmp, 0, MASK16_3); for (int iter = 0, tmpIdx = 0, longsIdx = 26; iter < 2; ++iter, tmpIdx += 13, longsIdx += 3) { - long l0 = (tmp[tmpIdx + 0] & MASK16_3) << 10; - l0 |= (tmp[tmpIdx + 1] & MASK16_3) << 7; - l0 |= (tmp[tmpIdx + 2] & MASK16_3) << 4; - l0 |= (tmp[tmpIdx + 3] & MASK16_3) << 1; + long l0 = tmp[tmpIdx + 0] << 10; + l0 |= tmp[tmpIdx + 1] << 7; + l0 |= tmp[tmpIdx + 2] << 4; + l0 |= tmp[tmpIdx + 3] << 1; l0 |= (tmp[tmpIdx + 4] >>> 2) & MASK16_1; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 4] & MASK16_2) << 11; - l1 |= (tmp[tmpIdx + 5] & MASK16_3) << 8; - l1 |= (tmp[tmpIdx + 6] & MASK16_3) << 5; - l1 |= (tmp[tmpIdx + 7] & MASK16_3) << 2; + l1 |= tmp[tmpIdx + 5] << 8; + l1 |= tmp[tmpIdx + 6] << 5; + l1 |= tmp[tmpIdx + 7] << 2; l1 |= (tmp[tmpIdx + 8] >>> 1) & MASK16_2; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 8] & MASK16_1) << 12; - l2 |= (tmp[tmpIdx + 9] & MASK16_3) << 9; - l2 |= (tmp[tmpIdx + 10] & MASK16_3) << 6; - l2 |= (tmp[tmpIdx + 11] & MASK16_3) << 3; - l2 |= (tmp[tmpIdx + 12] & MASK16_3) << 0; + l2 |= tmp[tmpIdx + 9] << 9; + l2 |= tmp[tmpIdx + 10] << 6; + l2 |= tmp[tmpIdx + 11] << 3; + l2 |= tmp[tmpIdx + 12] << 0; longs[longsIdx + 2] = l2; } } - private static void decode14(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 28); - shiftLongs(tmp, 28, longs, 0, 2, MASK16_14); - shiftLongs(tmp, 28, tmp, 0, 0, MASK16_2); + private static void decode14(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(28, longs, 2, MASK16_14, tmp, 0, MASK16_2); for (int iter = 0, tmpIdx = 0, longsIdx = 28; iter < 4; ++iter, tmpIdx += 7, longsIdx += 1) { long l0 = tmp[tmpIdx + 0] << 12; l0 |= tmp[tmpIdx + 1] << 10; @@ -833,10 +841,9 @@ final class ForUtil { } } - private static void decode15(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 30); - shiftLongs(tmp, 30, longs, 0, 1, MASK16_15); - shiftLongs(tmp, 30, tmp, 0, 0, MASK16_1); + private static void decode15(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(30, longs, 1, MASK16_15, tmp, 0, MASK16_1); for (int iter = 0, tmpIdx = 0, longsIdx = 30; iter < 2; ++iter, tmpIdx += 15, longsIdx += 1) { long l0 = tmp[tmpIdx + 0] << 14; l0 |= tmp[tmpIdx + 1] << 13; @@ -857,15 +864,16 @@ final class ForUtil { } } - private static void decode16(DataInput in, long[] tmp, long[] longs) throws IOException { + private static void decode16(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { in.readLongs(longs, 0, 32); } - private static void decode17(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 34); - shiftLongs(tmp, 34, longs, 0, 15, MASK32_17); + private static void decode17(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(34, longs, 15, MASK32_17, tmp, 0, MASK32_15); for (int iter = 0, tmpIdx = 0, longsIdx = 34; iter < 2; ++iter, tmpIdx += 17, longsIdx += 15) { - long l0 = (tmp[tmpIdx + 0] & MASK32_15) << 2; + long l0 = tmp[tmpIdx + 0] << 2; l0 |= (tmp[tmpIdx + 1] >>> 13) & MASK32_2; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK32_13) << 4; @@ -887,7 +895,7 @@ final class ForUtil { l6 |= (tmp[tmpIdx + 7] >>> 1) & MASK32_14; longs[longsIdx + 6] = l6; long l7 = (tmp[tmpIdx + 7] & MASK32_1) << 16; - l7 |= (tmp[tmpIdx + 8] & MASK32_15) << 1; + l7 |= tmp[tmpIdx + 8] << 1; l7 |= (tmp[tmpIdx + 9] >>> 14) & MASK32_1; longs[longsIdx + 7] = l7; long l8 = (tmp[tmpIdx + 9] & MASK32_14) << 3; @@ -909,16 +917,16 @@ final class ForUtil { l13 |= (tmp[tmpIdx + 15] >>> 2) & MASK32_13; longs[longsIdx + 13] = l13; long l14 = (tmp[tmpIdx + 15] & MASK32_2) << 15; - l14 |= (tmp[tmpIdx + 16] & MASK32_15) << 0; + l14 |= tmp[tmpIdx + 16] << 0; longs[longsIdx + 14] = l14; } } - private static void decode18(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 36); - shiftLongs(tmp, 36, longs, 0, 14, MASK32_18); + private static void decode18(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(36, longs, 14, MASK32_18, tmp, 0, MASK32_14); for (int iter = 0, tmpIdx = 0, longsIdx = 36; iter < 4; ++iter, tmpIdx += 9, longsIdx += 7) { - long l0 = (tmp[tmpIdx + 0] & MASK32_14) << 4; + long l0 = tmp[tmpIdx + 0] << 4; l0 |= (tmp[tmpIdx + 1] >>> 10) & MASK32_4; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK32_10) << 8; @@ -928,7 +936,7 @@ final class ForUtil { l2 |= (tmp[tmpIdx + 3] >>> 2) & MASK32_12; longs[longsIdx + 2] = l2; long l3 = (tmp[tmpIdx + 3] & MASK32_2) << 16; - l3 |= (tmp[tmpIdx + 4] & MASK32_14) << 2; + l3 |= tmp[tmpIdx + 4] << 2; l3 |= (tmp[tmpIdx + 5] >>> 12) & MASK32_2; longs[longsIdx + 3] = l3; long l4 = (tmp[tmpIdx + 5] & MASK32_12) << 6; @@ -938,206 +946,205 @@ final class ForUtil { l5 |= (tmp[tmpIdx + 7] >>> 4) & MASK32_10; longs[longsIdx + 5] = l5; long l6 = (tmp[tmpIdx + 7] & MASK32_4) << 14; - l6 |= (tmp[tmpIdx + 8] & MASK32_14) << 0; + l6 |= tmp[tmpIdx + 8] << 0; longs[longsIdx + 6] = l6; } } - private static void decode19(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 38); - shiftLongs(tmp, 38, longs, 0, 13, MASK32_19); + private static void decode19(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(38, longs, 13, MASK32_19, tmp, 0, MASK32_13); for (int iter = 0, tmpIdx = 0, longsIdx = 38; iter < 2; ++iter, tmpIdx += 19, longsIdx += 13) { - long l0 = (tmp[tmpIdx + 0] & MASK32_13) << 6; + long l0 = tmp[tmpIdx + 0] << 6; l0 |= (tmp[tmpIdx + 1] >>> 7) & MASK32_6; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK32_7) << 12; l1 |= (tmp[tmpIdx + 2] >>> 1) & MASK32_12; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 2] & MASK32_1) << 18; - l2 |= (tmp[tmpIdx + 3] & MASK32_13) << 5; + l2 |= tmp[tmpIdx + 3] << 5; l2 |= (tmp[tmpIdx + 4] >>> 8) & MASK32_5; longs[longsIdx + 2] = l2; long l3 = (tmp[tmpIdx + 4] & MASK32_8) << 11; l3 |= (tmp[tmpIdx + 5] >>> 2) & MASK32_11; longs[longsIdx + 3] = l3; long l4 = (tmp[tmpIdx + 5] & MASK32_2) << 17; - l4 |= (tmp[tmpIdx + 6] & MASK32_13) << 4; + l4 |= tmp[tmpIdx + 6] << 4; l4 |= (tmp[tmpIdx + 7] >>> 9) & MASK32_4; longs[longsIdx + 4] = l4; long l5 = (tmp[tmpIdx + 7] & MASK32_9) << 10; l5 |= (tmp[tmpIdx + 8] >>> 3) & MASK32_10; longs[longsIdx + 5] = l5; long l6 = (tmp[tmpIdx + 8] & MASK32_3) << 16; - l6 |= (tmp[tmpIdx + 9] & MASK32_13) << 3; + l6 |= tmp[tmpIdx + 9] << 3; l6 |= (tmp[tmpIdx + 10] >>> 10) & MASK32_3; longs[longsIdx + 6] = l6; long l7 = (tmp[tmpIdx + 10] & MASK32_10) << 9; l7 |= (tmp[tmpIdx + 11] >>> 4) & MASK32_9; longs[longsIdx + 7] = l7; long l8 = (tmp[tmpIdx + 11] & MASK32_4) << 15; - l8 |= (tmp[tmpIdx + 12] & MASK32_13) << 2; + l8 |= tmp[tmpIdx + 12] << 2; l8 |= (tmp[tmpIdx + 13] >>> 11) & MASK32_2; longs[longsIdx + 8] = l8; long l9 = (tmp[tmpIdx + 13] & MASK32_11) << 8; l9 |= (tmp[tmpIdx + 14] >>> 5) & MASK32_8; longs[longsIdx + 9] = l9; long l10 = (tmp[tmpIdx + 14] & MASK32_5) << 14; - l10 |= (tmp[tmpIdx + 15] & MASK32_13) << 1; + l10 |= tmp[tmpIdx + 15] << 1; l10 |= (tmp[tmpIdx + 16] >>> 12) & MASK32_1; longs[longsIdx + 10] = l10; long l11 = (tmp[tmpIdx + 16] & MASK32_12) << 7; l11 |= (tmp[tmpIdx + 17] >>> 6) & MASK32_7; longs[longsIdx + 11] = l11; long l12 = (tmp[tmpIdx + 17] & MASK32_6) << 13; - l12 |= (tmp[tmpIdx + 18] & MASK32_13) << 0; + l12 |= tmp[tmpIdx + 18] << 0; longs[longsIdx + 12] = l12; } } - private static void decode20(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 40); - shiftLongs(tmp, 40, longs, 0, 12, MASK32_20); + private static void decode20(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(40, longs, 12, MASK32_20, tmp, 0, MASK32_12); for (int iter = 0, tmpIdx = 0, longsIdx = 40; iter < 8; ++iter, tmpIdx += 5, longsIdx += 3) { - long l0 = (tmp[tmpIdx + 0] & MASK32_12) << 8; + long l0 = tmp[tmpIdx + 0] << 8; l0 |= (tmp[tmpIdx + 1] >>> 4) & MASK32_8; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK32_4) << 16; - l1 |= (tmp[tmpIdx + 2] & MASK32_12) << 4; + l1 |= tmp[tmpIdx + 2] << 4; l1 |= (tmp[tmpIdx + 3] >>> 8) & MASK32_4; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 3] & MASK32_8) << 12; - l2 |= (tmp[tmpIdx + 4] & MASK32_12) << 0; + l2 |= tmp[tmpIdx + 4] << 0; longs[longsIdx + 2] = l2; } } - private static void decode21(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 42); - shiftLongs(tmp, 42, longs, 0, 11, MASK32_21); + private static void decode21(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(42, longs, 11, MASK32_21, tmp, 0, MASK32_11); for (int iter = 0, tmpIdx = 0, longsIdx = 42; iter < 2; ++iter, tmpIdx += 21, longsIdx += 11) { - long l0 = (tmp[tmpIdx + 0] & MASK32_11) << 10; + long l0 = tmp[tmpIdx + 0] << 10; l0 |= (tmp[tmpIdx + 1] >>> 1) & MASK32_10; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 1] & MASK32_1) << 20; - l1 |= (tmp[tmpIdx + 2] & MASK32_11) << 9; + l1 |= tmp[tmpIdx + 2] << 9; l1 |= (tmp[tmpIdx + 3] >>> 2) & MASK32_9; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 3] & MASK32_2) << 19; - l2 |= (tmp[tmpIdx + 4] & MASK32_11) << 8; + l2 |= tmp[tmpIdx + 4] << 8; l2 |= (tmp[tmpIdx + 5] >>> 3) & MASK32_8; longs[longsIdx + 2] = l2; long l3 = (tmp[tmpIdx + 5] & MASK32_3) << 18; - l3 |= (tmp[tmpIdx + 6] & MASK32_11) << 7; + l3 |= tmp[tmpIdx + 6] << 7; l3 |= (tmp[tmpIdx + 7] >>> 4) & MASK32_7; longs[longsIdx + 3] = l3; long l4 = (tmp[tmpIdx + 7] & MASK32_4) << 17; - l4 |= (tmp[tmpIdx + 8] & MASK32_11) << 6; + l4 |= tmp[tmpIdx + 8] << 6; l4 |= (tmp[tmpIdx + 9] >>> 5) & MASK32_6; longs[longsIdx + 4] = l4; long l5 = (tmp[tmpIdx + 9] & MASK32_5) << 16; - l5 |= (tmp[tmpIdx + 10] & MASK32_11) << 5; + l5 |= tmp[tmpIdx + 10] << 5; l5 |= (tmp[tmpIdx + 11] >>> 6) & MASK32_5; longs[longsIdx + 5] = l5; long l6 = (tmp[tmpIdx + 11] & MASK32_6) << 15; - l6 |= (tmp[tmpIdx + 12] & MASK32_11) << 4; + l6 |= tmp[tmpIdx + 12] << 4; l6 |= (tmp[tmpIdx + 13] >>> 7) & MASK32_4; longs[longsIdx + 6] = l6; long l7 = (tmp[tmpIdx + 13] & MASK32_7) << 14; - l7 |= (tmp[tmpIdx + 14] & MASK32_11) << 3; + l7 |= tmp[tmpIdx + 14] << 3; l7 |= (tmp[tmpIdx + 15] >>> 8) & MASK32_3; longs[longsIdx + 7] = l7; long l8 = (tmp[tmpIdx + 15] & MASK32_8) << 13; - l8 |= (tmp[tmpIdx + 16] & MASK32_11) << 2; + l8 |= tmp[tmpIdx + 16] << 2; l8 |= (tmp[tmpIdx + 17] >>> 9) & MASK32_2; longs[longsIdx + 8] = l8; long l9 = (tmp[tmpIdx + 17] & MASK32_9) << 12; - l9 |= (tmp[tmpIdx + 18] & MASK32_11) << 1; + l9 |= tmp[tmpIdx + 18] << 1; l9 |= (tmp[tmpIdx + 19] >>> 10) & MASK32_1; longs[longsIdx + 9] = l9; long l10 = (tmp[tmpIdx + 19] & MASK32_10) << 11; - l10 |= (tmp[tmpIdx + 20] & MASK32_11) << 0; + l10 |= tmp[tmpIdx + 20] << 0; longs[longsIdx + 10] = l10; } } - private static void decode22(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 44); - shiftLongs(tmp, 44, longs, 0, 10, MASK32_22); + private static void decode22(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(44, longs, 10, MASK32_22, tmp, 0, MASK32_10); for (int iter = 0, tmpIdx = 0, longsIdx = 44; iter < 4; ++iter, tmpIdx += 11, longsIdx += 5) { - long l0 = (tmp[tmpIdx + 0] & MASK32_10) << 12; - l0 |= (tmp[tmpIdx + 1] & MASK32_10) << 2; + long l0 = tmp[tmpIdx + 0] << 12; + l0 |= tmp[tmpIdx + 1] << 2; l0 |= (tmp[tmpIdx + 2] >>> 8) & MASK32_2; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 2] & MASK32_8) << 14; - l1 |= (tmp[tmpIdx + 3] & MASK32_10) << 4; + l1 |= tmp[tmpIdx + 3] << 4; l1 |= (tmp[tmpIdx + 4] >>> 6) & MASK32_4; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 4] & MASK32_6) << 16; - l2 |= (tmp[tmpIdx + 5] & MASK32_10) << 6; + l2 |= tmp[tmpIdx + 5] << 6; l2 |= (tmp[tmpIdx + 6] >>> 4) & MASK32_6; longs[longsIdx + 2] = l2; long l3 = (tmp[tmpIdx + 6] & MASK32_4) << 18; - l3 |= (tmp[tmpIdx + 7] & MASK32_10) << 8; + l3 |= tmp[tmpIdx + 7] << 8; l3 |= (tmp[tmpIdx + 8] >>> 2) & MASK32_8; longs[longsIdx + 3] = l3; long l4 = (tmp[tmpIdx + 8] & MASK32_2) << 20; - l4 |= (tmp[tmpIdx + 9] & MASK32_10) << 10; - l4 |= (tmp[tmpIdx + 10] & MASK32_10) << 0; + l4 |= tmp[tmpIdx + 9] << 10; + l4 |= tmp[tmpIdx + 10] << 0; longs[longsIdx + 4] = l4; } } - private static void decode23(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 46); - shiftLongs(tmp, 46, longs, 0, 9, MASK32_23); + private static void decode23(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(46, longs, 9, MASK32_23, tmp, 0, MASK32_9); for (int iter = 0, tmpIdx = 0, longsIdx = 46; iter < 2; ++iter, tmpIdx += 23, longsIdx += 9) { - long l0 = (tmp[tmpIdx + 0] & MASK32_9) << 14; - l0 |= (tmp[tmpIdx + 1] & MASK32_9) << 5; + long l0 = tmp[tmpIdx + 0] << 14; + l0 |= tmp[tmpIdx + 1] << 5; l0 |= (tmp[tmpIdx + 2] >>> 4) & MASK32_5; longs[longsIdx + 0] = l0; long l1 = (tmp[tmpIdx + 2] & MASK32_4) << 19; - l1 |= (tmp[tmpIdx + 3] & MASK32_9) << 10; - l1 |= (tmp[tmpIdx + 4] & MASK32_9) << 1; + l1 |= tmp[tmpIdx + 3] << 10; + l1 |= tmp[tmpIdx + 4] << 1; l1 |= (tmp[tmpIdx + 5] >>> 8) & MASK32_1; longs[longsIdx + 1] = l1; long l2 = (tmp[tmpIdx + 5] & MASK32_8) << 15; - l2 |= (tmp[tmpIdx + 6] & MASK32_9) << 6; + l2 |= tmp[tmpIdx + 6] << 6; l2 |= (tmp[tmpIdx + 7] >>> 3) & MASK32_6; longs[longsIdx + 2] = l2; long l3 = (tmp[tmpIdx + 7] & MASK32_3) << 20; - l3 |= (tmp[tmpIdx + 8] & MASK32_9) << 11; - l3 |= (tmp[tmpIdx + 9] & MASK32_9) << 2; + l3 |= tmp[tmpIdx + 8] << 11; + l3 |= tmp[tmpIdx + 9] << 2; l3 |= (tmp[tmpIdx + 10] >>> 7) & MASK32_2; longs[longsIdx + 3] = l3; long l4 = (tmp[tmpIdx + 10] & MASK32_7) << 16; - l4 |= (tmp[tmpIdx + 11] & MASK32_9) << 7; + l4 |= tmp[tmpIdx + 11] << 7; l4 |= (tmp[tmpIdx + 12] >>> 2) & MASK32_7; longs[longsIdx + 4] = l4; long l5 = (tmp[tmpIdx + 12] & MASK32_2) << 21; - l5 |= (tmp[tmpIdx + 13] & MASK32_9) << 12; - l5 |= (tmp[tmpIdx + 14] & MASK32_9) << 3; + l5 |= tmp[tmpIdx + 13] << 12; + l5 |= tmp[tmpIdx + 14] << 3; l5 |= (tmp[tmpIdx + 15] >>> 6) & MASK32_3; longs[longsIdx + 5] = l5; long l6 = (tmp[tmpIdx + 15] & MASK32_6) << 17; - l6 |= (tmp[tmpIdx + 16] & MASK32_9) << 8; + l6 |= tmp[tmpIdx + 16] << 8; l6 |= (tmp[tmpIdx + 17] >>> 1) & MASK32_8; longs[longsIdx + 6] = l6; long l7 = (tmp[tmpIdx + 17] & MASK32_1) << 22; - l7 |= (tmp[tmpIdx + 18] & MASK32_9) << 13; - l7 |= (tmp[tmpIdx + 19] & MASK32_9) << 4; + l7 |= tmp[tmpIdx + 18] << 13; + l7 |= tmp[tmpIdx + 19] << 4; l7 |= (tmp[tmpIdx + 20] >>> 5) & MASK32_4; longs[longsIdx + 7] = l7; long l8 = (tmp[tmpIdx + 20] & MASK32_5) << 18; - l8 |= (tmp[tmpIdx + 21] & MASK32_9) << 9; - l8 |= (tmp[tmpIdx + 22] & MASK32_9) << 0; + l8 |= tmp[tmpIdx + 21] << 9; + l8 |= tmp[tmpIdx + 22] << 0; longs[longsIdx + 8] = l8; } } - private static void decode24(DataInput in, long[] tmp, long[] longs) throws IOException { - in.readLongs(tmp, 0, 48); - shiftLongs(tmp, 48, longs, 0, 8, MASK32_24); - shiftLongs(tmp, 48, tmp, 0, 0, MASK32_8); + private static void decode24(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) + throws IOException { + pdu.splitLongs(48, longs, 8, MASK32_24, tmp, 0, MASK32_8); for (int iter = 0, tmpIdx = 0, longsIdx = 48; iter < 16; ++iter, tmpIdx += 3, longsIdx += 1) { long l0 = tmp[tmpIdx + 0] << 16; l0 |= tmp[tmpIdx + 1] << 8; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsReader.java index 5e66a200929..491b4507cf1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsReader.java @@ -352,6 +352,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { final IndexInput startDocIn; IndexInput docIn; + PostingIndexInput postingDocIn; final boolean indexHasFreq; final boolean indexHasPos; final boolean indexHasOffsetsOrPayloads; @@ -413,6 +414,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { if (docIn == null) { // lazy init docIn = startDocIn.clone(); + postingDocIn = new PostingIndexInput(docIn, forUtil); } prefetchPostings(docIn, termState); } @@ -446,7 +448,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { public int freq() throws IOException { if (freqFP != -1) { docIn.seek(freqFP); - pforUtil.decode(docIn, freqBuffer); + pforUtil.decode(postingDocIn, freqBuffer); freqFP = -1; } @@ -481,7 +483,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { private void refillFullBlock() throws IOException { assert docFreq - docCountUpto >= BLOCK_SIZE; - forDeltaUtil.decodeAndPrefixSum(docIn, prevDocID, docBuffer); + forDeltaUtil.decodeAndPrefixSum(postingDocIn, prevDocID, docBuffer); if (indexHasFreq) { if (needsFreq) { @@ -649,8 +651,11 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { final IndexInput startDocIn; IndexInput docIn; + PostingIndexInput postingDocIn; final IndexInput posIn; + final PostingIndexInput postingPosIn; final IndexInput payIn; + final PostingIndexInput postingPayIn; final BytesRef payload; final boolean indexHasFreq; @@ -718,10 +723,13 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { indexHasOffsetsOrPayloads = indexHasOffsets || indexHasPayloads; this.posIn = Lucene912PostingsReader.this.posIn.clone(); + postingPosIn = new PostingIndexInput(posIn, forUtil); if (indexHasOffsetsOrPayloads) { this.payIn = Lucene912PostingsReader.this.payIn.clone(); + postingPayIn = new PostingIndexInput(payIn, forUtil); } else { this.payIn = null; + postingPayIn = null; } if (indexHasOffsets) { offsetStartDeltaBuffer = new long[BLOCK_SIZE]; @@ -768,6 +776,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { if (docIn == null) { // lazy init docIn = startDocIn.clone(); + postingDocIn = new PostingIndexInput(docIn, forUtil); } prefetchPostings(docIn, termState); } @@ -830,8 +839,8 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { assert left >= 0; if (left >= BLOCK_SIZE) { - forDeltaUtil.decodeAndPrefixSum(docIn, prevDocID, docBuffer); - pforUtil.decode(docIn, freqBuffer); + forDeltaUtil.decodeAndPrefixSum(postingDocIn, prevDocID, docBuffer); + pforUtil.decode(postingDocIn, freqBuffer); docCountUpto += BLOCK_SIZE; } else if (docFreq == 1) { docBuffer[0] = singletonDocID; @@ -1110,11 +1119,11 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { } payloadByteUpto = 0; } else { - pforUtil.decode(posIn, posDeltaBuffer); + pforUtil.decode(postingPosIn, posDeltaBuffer); if (indexHasPayloads) { if (needsPayloads) { - pforUtil.decode(payIn, payloadLengthBuffer); + pforUtil.decode(postingPayIn, payloadLengthBuffer); int numBytes = payIn.readVInt(); if (numBytes > payloadBytes.length) { @@ -1133,8 +1142,8 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { if (indexHasOffsets) { if (needsOffsets) { - pforUtil.decode(payIn, offsetStartDeltaBuffer); - pforUtil.decode(payIn, offsetLengthBuffer); + pforUtil.decode(postingPayIn, offsetStartDeltaBuffer); + pforUtil.decode(postingPayIn, offsetLengthBuffer); } else { // this works, because when writing a vint block we always force the first length to be // written @@ -1217,7 +1226,8 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { final IndexInput startDocIn; - IndexInput docIn; + final IndexInput docIn; + final PostingIndexInput postingDocIn; final boolean indexHasFreq; final boolean indexHasPos; final boolean indexHasOffsetsOrPayloads; @@ -1248,7 +1258,6 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { public BlockImpactsDocsEnum(FieldInfo fieldInfo, IntBlockTermState termState) throws IOException { this.startDocIn = Lucene912PostingsReader.this.docIn; - this.docIn = null; indexHasFreq = fieldInfo.getIndexOptions().compareTo(IndexOptions.DOCS_AND_FREQS) >= 0; indexHasPos = fieldInfo.getIndexOptions().compareTo(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) >= 0; @@ -1264,11 +1273,12 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { docFreq = termState.docFreq; if (docFreq > 1) { - if (docIn == null) { - // lazy init - docIn = startDocIn.clone(); - } + docIn = startDocIn.clone(); + postingDocIn = new PostingIndexInput(docIn, forUtil); prefetchPostings(docIn, termState); + } else { + docIn = null; + postingDocIn = null; } doc = -1; @@ -1302,7 +1312,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { public int freq() throws IOException { if (freqFP != -1) { docIn.seek(freqFP); - pforUtil.decode(docIn, freqBuffer); + pforUtil.decode(postingDocIn, freqBuffer); freqFP = -1; } return (int) freqBuffer[docBufferUpto - 1]; @@ -1338,7 +1348,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { assert left >= 0; if (left >= BLOCK_SIZE) { - forDeltaUtil.decodeAndPrefixSum(docIn, prevDocID, docBuffer); + forDeltaUtil.decodeAndPrefixSum(postingDocIn, prevDocID, docBuffer); if (indexHasFreq) { freqFP = docIn.getFilePointer(); @@ -1573,8 +1583,10 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { final IndexInput startDocIn; - IndexInput docIn; + final IndexInput docIn; + final PostingIndexInput postingDocIn; final IndexInput posIn; + final PostingIndexInput postingPosIn; final boolean indexHasFreq; final boolean indexHasPos; @@ -1628,7 +1640,6 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { public BlockImpactsPostingsEnum(FieldInfo fieldInfo, IntBlockTermState termState) throws IOException { this.startDocIn = Lucene912PostingsReader.this.docIn; - this.docIn = null; indexHasFreq = fieldInfo.getIndexOptions().compareTo(IndexOptions.DOCS_AND_FREQS) >= 0; indexHasPos = fieldInfo.getIndexOptions().compareTo(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) >= 0; @@ -1641,6 +1652,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { indexHasOffsetsOrPayloads = indexHasOffsets || indexHasPayloads; this.posIn = Lucene912PostingsReader.this.posIn.clone(); + postingPosIn = new PostingIndexInput(posIn, forUtil); // We set the last element of docBuffer to NO_MORE_DOCS, it helps save conditionals in // advance() @@ -1651,11 +1663,12 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { totalTermFreq = termState.totalTermFreq; singletonDocID = termState.singletonDocID; if (docFreq > 1) { - if (docIn == null) { - // lazy init - docIn = startDocIn.clone(); - } + docIn = startDocIn.clone(); + postingDocIn = new PostingIndexInput(docIn, forUtil); prefetchPostings(docIn, termState); + } else { + docIn = null; + postingDocIn = null; } posIn.seek(posTermStartFP); level1PosEndFP = posTermStartFP; @@ -1707,8 +1720,8 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { assert left >= 0; if (left >= BLOCK_SIZE) { - forDeltaUtil.decodeAndPrefixSum(docIn, prevDocID, docBuffer); - pforUtil.decode(docIn, freqBuffer); + forDeltaUtil.decodeAndPrefixSum(postingDocIn, prevDocID, docBuffer); + pforUtil.decode(postingDocIn, freqBuffer); docCountUpto += BLOCK_SIZE; } else if (docFreq == 1) { docBuffer[0] = singletonDocID; @@ -1981,7 +1994,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase { } } } else { - pforUtil.decode(posIn, posDeltaBuffer); + pforUtil.decode(postingPosIn, posDeltaBuffer); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsWriter.java index b307080b215..b3c6503449a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/Lucene912PostingsWriter.java @@ -16,10 +16,12 @@ */ package org.apache.lucene.codecs.lucene912; -import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.*; import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.BLOCK_SIZE; import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.DOC_CODEC; +import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.LEVEL1_MASK; +import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.META_CODEC; import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.PAY_CODEC; +import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.POS_CODEC; import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.TERMS_CODEC; import static org.apache.lucene.codecs.lucene912.Lucene912PostingsFormat.VERSION_CURRENT; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/PForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/PForUtil.java index f4405ae66fa..4fbe7051f3b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/PForUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/PForUtil.java @@ -104,18 +104,18 @@ final class PForUtil { } /** Decode 128 integers into {@code ints}. */ - void decode(DataInput in, long[] longs) throws IOException { - final int token = Byte.toUnsignedInt(in.readByte()); + void decode(PostingIndexInput in, long[] longs) throws IOException { + final int token = Byte.toUnsignedInt(in.in.readByte()); final int bitsPerValue = token & 0x1f; final int numExceptions = token >>> 5; if (bitsPerValue == 0) { - Arrays.fill(longs, 0, ForUtil.BLOCK_SIZE, in.readVLong()); + Arrays.fill(longs, 0, ForUtil.BLOCK_SIZE, in.in.readVLong()); } else { - forUtil.decode(bitsPerValue, in, longs); + in.decode(bitsPerValue, longs); } for (int i = 0; i < numExceptions; ++i) { - longs[Byte.toUnsignedInt(in.readByte())] |= - Byte.toUnsignedLong(in.readByte()) << bitsPerValue; + longs[Byte.toUnsignedInt(in.in.readByte())] |= + Byte.toUnsignedLong(in.in.readByte()) << bitsPerValue; } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/PostingIndexInput.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/PostingIndexInput.java new file mode 100644 index 00000000000..88067b02858 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/PostingIndexInput.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene912; + +import java.io.IOException; +import org.apache.lucene.internal.vectorization.PostingDecodingUtil; +import org.apache.lucene.internal.vectorization.VectorizationProvider; +import org.apache.lucene.store.IndexInput; + +/** + * Wrapper around an {@link IndexInput} and a {@link ForUtil} that optionally optimizes decoding + * using vectorization. + */ +public final class PostingIndexInput { + + private static final VectorizationProvider VECTORIZATION_PROVIDER = + VectorizationProvider.getInstance(); + + public final IndexInput in; + public final ForUtil forUtil; + private final PostingDecodingUtil postingDecodingUtil; + + public PostingIndexInput(IndexInput in, ForUtil forUtil) throws IOException { + this.in = in; + this.forUtil = forUtil; + this.postingDecodingUtil = VECTORIZATION_PROVIDER.newPostingDecodingUtil(in); + } + + /** Decode 128 integers stored on {@code bitsPerValues} bits per value into {@code longs}. */ + public void decode(int bitsPerValue, long[] longs) throws IOException { + forUtil.decode(bitsPerValue, in, postingDecodingUtil, longs); + } + + /** + * Decode 128 integers stored on {@code bitsPerValues} bits per value, compute their prefix sum, + * and store results into {@code longs}. + */ + public void decodeAndPrefixSum(int bitsPerValue, long base, long[] longs) throws IOException { + forUtil.decodeAndPrefixSum(bitsPerValue, in, postingDecodingUtil, base, longs); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py index c6a33ceef53..5e993e26555 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py @@ -43,17 +43,20 @@ HEADER = """// This file has been automatically generated, DO NOT EDIT package org.apache.lucene.codecs.lucene912; import java.io.IOException; -import org.apache.lucene.store.DataInput; +import org.apache.lucene.internal.vectorization.PostingDecodingUtil; import org.apache.lucene.store.DataOutput; +import org.apache.lucene.store.IndexInput; -// Inspired from https://fulmicoton.com/posts/bitpacking/ -// Encodes multiple integers in a long to get SIMD-like speedups. -// If bitsPerValue <= 8 then we pack 8 ints per long -// else if bitsPerValue <= 16 we pack 4 ints per long -// else we pack 2 ints per long -final class ForUtil { +/** + * Inspired from https://fulmicoton.com/posts/bitpacking/ + * Encodes multiple integers in a long to get SIMD-like speedups. + * If bitsPerValue <= 8 then we pack 8 ints per long + * else if bitsPerValue <= 16 we pack 4 ints per long + * else we pack 2 ints per long + */ +public final class ForUtil { - static final int BLOCK_SIZE = 128; + public static final int BLOCK_SIZE = 128; private static final int BLOCK_SIZE_LOG2 = 7; private static long expandMask32(long mask32) { @@ -324,13 +327,13 @@ final class ForUtil { return bitsPerValue << (BLOCK_SIZE_LOG2 - 3); } - private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) + private static void decodeSlow(int bitsPerValue, IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException { final int numLongs = bitsPerValue << 1; - in.readLongs(tmp, 0, numLongs); final long mask = MASKS32[bitsPerValue]; - int longsIdx = 0; - int shift = 32 - bitsPerValue; + pdu.splitLongs(numLongs, longs, 32 - bitsPerValue, mask, tmp, 0, -1L); + int longsIdx = numLongs; + int shift = 32 - 2 * bitsPerValue; for (; shift >= 0; shift -= bitsPerValue) { shiftLongs(tmp, numLongs, longs, longsIdx, shift, mask); longsIdx += numLongs; @@ -366,31 +369,17 @@ final class ForUtil { } } + /** + * Likewise, but for a simple mask. + */ + private static void maskLongs(long[] a, int count, long[] b, int bi, long mask) { + for (int i = 0; i < count; ++i) { + b[bi + i] = a[i] & mask; + } + } + """ -def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long, o, num_values, f): - iteration = 1 - num_longs = bpv * num_values / remaining_bits_per_long - while num_longs % 2 == 0 and num_values % 2 == 0: - num_longs /= 2 - num_values /= 2 - iteration *= 2 - - f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long)) - f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values)) - tmp_idx = 0 - b = bpv - b -= remaining_bits_per_long - f.write(' long l0 = tmp[tmpIdx + %d] << %d;\n' %(tmp_idx, b)) - tmp_idx += 1 - while b >= remaining_bits_per_long: - b -= remaining_bits_per_long - f.write(' l0 |= tmp[tmpIdx + %d] << %d;\n' %(tmp_idx, b)) - tmp_idx += 1 - f.write(' longs[longsIdx + 0] = l0;\n') - f.write(' }\n') - - def writeRemainder(bpv, next_primitive, remaining_bits_per_long, o, num_values, f): iteration = 1 num_longs = bpv * num_values / remaining_bits_per_long @@ -406,14 +395,14 @@ def writeRemainder(bpv, next_primitive, remaining_bits_per_long, o, num_values, b = bpv if remaining_bits == 0: b -= remaining_bits_per_long - f.write(' long l%d = (tmp[tmpIdx + %d] & MASK%d_%d) << %d;\n' %(i, tmp_idx, next_primitive, remaining_bits_per_long, b)) + f.write(' long l%d = tmp[tmpIdx + %d] << %d;\n' %(i, tmp_idx, b)) else: b -= remaining_bits f.write(' long l%d = (tmp[tmpIdx + %d] & MASK%d_%d) << %d;\n' %(i, tmp_idx, next_primitive, remaining_bits, b)) tmp_idx += 1 while b >= remaining_bits_per_long: b -= remaining_bits_per_long - f.write(' l%d |= (tmp[tmpIdx + %d] & MASK%d_%d) << %d;\n' %(i, tmp_idx, next_primitive, remaining_bits_per_long, b)) + f.write(' l%d |= tmp[tmpIdx + %d] << %d;\n' %(i, tmp_idx, b)) tmp_idx += 1 if b > 0: f.write(' l%d |= (tmp[tmpIdx + %d] >>> %d) & MASK%d_%d;\n' %(i, tmp_idx, remaining_bits_per_long-b, next_primitive, b)) @@ -428,23 +417,30 @@ def writeDecode(bpv, f): next_primitive = 8 elif bpv <= 16: next_primitive = 16 - f.write(' private static void decode%d(DataInput in, long[] tmp, long[] longs) throws IOException {\n' %bpv) - num_values_per_long = 64 / next_primitive + f.write(' private static void decode%d(IndexInput in, PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {\n' %bpv) if bpv == next_primitive: f.write(' in.readLongs(longs, 0, %d);\n' %(bpv*2)) + elif bpv * 2 == next_primitive: + f.write(' pdu.splitLongs(%d, longs, %d, MASK%d_%d, longs, %d, MASK%d_%d);\n' %(bpv*2, next_primitive - bpv, next_primitive, bpv, bpv*2, next_primitive, next_primitive - bpv)) else: - f.write(' in.readLongs(tmp, 0, %d);\n' %(bpv*2)) - shift = next_primitive - bpv - o = 0 + num_values_per_long = 64 / next_primitive + f.write(' pdu.splitLongs(%d, longs, %d, MASK%d_%d, tmp, 0, MASK%d_%d);\n' %(bpv*2, next_primitive - bpv, next_primitive, bpv, next_primitive, next_primitive - bpv)) + + shift = next_primitive - 2 * bpv + o = 2 * bpv while shift >= 0: - f.write(' shiftLongs(tmp, %d, longs, %d, %d, MASK%d_%d);\n' %(bpv*2, o, shift, next_primitive, bpv)) + if shift == 0: + f.write(' maskLongs(tmp, %d, longs, %d, MASK%d_%d);\n' %(bpv*2, o, next_primitive, bpv)) + else: + f.write(' shiftLongs(tmp, %d, longs, %d, %d, MASK%d_%d);\n' %(bpv*2, o, shift, next_primitive, bpv)) o += bpv*2 shift -= bpv - if shift + bpv > 0: - if bpv % (next_primitive % bpv) == 0: - writeRemainderWithSIMDOptimize(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f) - else: - writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f) + remaining_bits = shift + bpv + if remaining_bits > 0: + if remaining_bits != next_primitive - bpv: + # values in tmp still have more bits per value than remaining_bits, clear the higher bits now + f.write(' maskLongs(tmp, %d, tmp, 0, MASK%d_%d);\n' %(bpv*2, next_primitive, remaining_bits)) + writeRemainder(bpv, next_primitive, remaining_bits, o, 128/num_values_per_long - o, f) f.write(' }\n') @@ -471,7 +467,7 @@ if __name__ == '__main__': f.write(""" /** Decode 128 integers into {@code longs}. */ - void decode(int bitsPerValue, DataInput in, long[] longs) throws IOException { + void decode(int bitsPerValue, IndexInput in, PostingDecodingUtil pdu, long[] longs) throws IOException { switch (bitsPerValue) { """) for bpv in range(1, MAX_SPECIALIZED_BITS_PER_VALUE+1): @@ -481,11 +477,11 @@ if __name__ == '__main__': elif bpv <= 16: next_primitive = 16 f.write(' case %d:\n' %bpv) - f.write(' decode%d(in, tmp, longs);\n' %bpv) + f.write(' decode%d(in, pdu, tmp, longs);\n' %bpv) f.write(' expand%d(longs);\n' %next_primitive) f.write(' break;\n') f.write(' default:\n') - f.write(' decodeSlow(bitsPerValue, in, tmp, longs);\n') + f.write(' decodeSlow(bitsPerValue, in, pdu, tmp, longs);\n') f.write(' expand32(longs);\n') f.write(' break;\n') f.write(' }\n') @@ -495,7 +491,7 @@ if __name__ == '__main__': /** * Delta-decode 128 integers into {@code longs}. */ - void decodeAndPrefixSum(int bitsPerValue, DataInput in, long base, long[] longs) throws IOException { + void decodeAndPrefixSum(int bitsPerValue, IndexInput in, PostingDecodingUtil pdu, long base, long[] longs) throws IOException { switch (bitsPerValue) { """) for bpv in range(1, MAX_SPECIALIZED_BITS_PER_VALUE+1): @@ -505,11 +501,11 @@ if __name__ == '__main__': elif bpv <= 16: next_primitive = 16 f.write(' case %d:\n' %bpv) - f.write(' decode%d(in, tmp, longs);\n' %bpv) + f.write(' decode%d(in, pdu, tmp, longs);\n' %bpv) f.write(' prefixSum%d(longs, base);\n' %next_primitive) f.write(' break;\n') f.write(' default:\n') - f.write(' decodeSlow(bitsPerValue, in, tmp, longs);\n') + f.write(' decodeSlow(bitsPerValue, in, pdu, tmp, longs);\n') f.write(' prefixSum32(longs, base);\n') f.write(' break;\n') f.write(' }\n') diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultPostingDecodingUtil.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultPostingDecodingUtil.java new file mode 100644 index 00000000000..8c68e87109c --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultPostingDecodingUtil.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import org.apache.lucene.store.IndexInput; + +final class DefaultPostingDecodingUtil extends PostingDecodingUtil { + + protected final IndexInput in; + + public DefaultPostingDecodingUtil(IndexInput in) { + this.in = in; + } + + @Override + public void splitLongs( + int count, long[] b, int bShift, long bMask, long[] c, int cIndex, long cMask) + throws IOException { + assert count <= 64; + in.readLongs(c, cIndex, count); + // The below loop is auto-vectorized + for (int i = 0; i < count; ++i) { + b[i] = (c[cIndex + i] >>> bShift) & bMask; + c[cIndex + i] &= cMask; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index c5193aa23de..2127a594117 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -19,6 +19,7 @@ package org.apache.lucene.internal.vectorization; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.store.IndexInput; /** Default provider returning scalar implementations. */ final class DefaultVectorizationProvider extends VectorizationProvider { @@ -38,4 +39,9 @@ final class DefaultVectorizationProvider extends VectorizationProvider { public FlatVectorsScorer getLucene99FlatVectorsScorer() { return DefaultFlatVectorScorer.INSTANCE; } + + @Override + public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) { + return new DefaultPostingDecodingUtil(input); + } } diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/PostingDecodingUtil.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/PostingDecodingUtil.java new file mode 100644 index 00000000000..d5928959e28 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/PostingDecodingUtil.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; + +/** Utility class to decode postings. */ +public abstract class PostingDecodingUtil { + + /** + * Read {@code count} longs. This number must not exceed 64. Apply shift {@code bShift} and mask + * {@code bMask} and store the result in {@code b} starting at offset 0. Apply mask {@code cMask} + * and store the result in {@code c} starting at offset {@code cIndex}. + */ + public abstract void splitLongs( + int count, long[] b, int bShift, long bMask, long[] c, int cIndex, long cMask) + throws IOException; +} diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index a236c303eb4..eeb1830fc91 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -17,6 +17,7 @@ package org.apache.lucene.internal.vectorization; +import java.io.IOException; import java.lang.StackWalker.StackFrame; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; @@ -28,6 +29,7 @@ import java.util.function.Predicate; import java.util.logging.Logger; import java.util.stream.Stream; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Constants; import org.apache.lucene.util.VectorUtil; @@ -95,6 +97,9 @@ public abstract class VectorizationProvider { /** Returns a FlatVectorsScorer that supports the Lucene99 format. */ public abstract FlatVectorsScorer getLucene99FlatVectorsScorer(); + /** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */ + public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException; + // *** Lookup mechanism: *** private static final Logger LOG = Logger.getLogger(VectorizationProvider.class.getName()); @@ -184,7 +189,8 @@ public abstract class VectorizationProvider { private static final Set VALID_CALLERS = Set.of( "org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil", - "org.apache.lucene.util.VectorUtil"); + "org.apache.lucene.util.VectorUtil", + "org.apache.lucene.codecs.lucene912.PostingIndexInput"); private static void ensureCaller() { final boolean validCaller = diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/MemorySegmentPostingDecodingUtil.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/MemorySegmentPostingDecodingUtil.java new file mode 100644 index 00000000000..7b4bc32bccf --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/MemorySegmentPostingDecodingUtil.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.nio.ByteOrder; +import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorSpecies; +import org.apache.lucene.store.IndexInput; + +final class MemorySegmentPostingDecodingUtil extends PostingDecodingUtil { + + private static final VectorSpecies LONG_SPECIES = + PanamaVectorConstants.PRERERRED_LONG_SPECIES; + + private final IndexInput in; + private final MemorySegment memorySegment; + + MemorySegmentPostingDecodingUtil(IndexInput in, MemorySegment memorySegment) { + this.in = in; + this.memorySegment = memorySegment; + } + + @Override + public void splitLongs( + int count, long[] b, int bShift, long bMask, long[] c, int cIndex, long cMask) + throws IOException { + if (count < LONG_SPECIES.length()) { + // Not enough data to vectorize without going out-of-bounds. In practice, this branch is never + // used if the bit width is 256, and is used for 2 and 3 bits per value if the bit width is + // 512. + in.readLongs(c, cIndex, count); + for (int i = 0; i < count; ++i) { + b[i] = (c[cIndex + i] >>> bShift) & bMask; + c[cIndex + i] &= cMask; + } + } else { + long offset = in.getFilePointer(); + long endOffset = offset + count * Long.BYTES; + int loopBound = LONG_SPECIES.loopBound(count - 1); + for (int i = 0; + i < loopBound; + i += LONG_SPECIES.length(), offset += LONG_SPECIES.length() * Long.BYTES) { + LongVector vector = + LongVector.fromMemorySegment( + LONG_SPECIES, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + vector + .lanewise(VectorOperators.LSHR, bShift) + .lanewise(VectorOperators.AND, bMask) + .intoArray(b, i); + vector.lanewise(VectorOperators.AND, cMask).intoArray(c, cIndex + i); + } + + // Handle the tail by reading a vector that is aligned with with `count` on the right side. + int i = count - LONG_SPECIES.length(); + offset = endOffset - LONG_SPECIES.length() * Long.BYTES; + LongVector vector = + LongVector.fromMemorySegment( + LONG_SPECIES, memorySegment, offset, ByteOrder.LITTLE_ENDIAN); + vector + .lanewise(VectorOperators.LSHR, bShift) + .lanewise(VectorOperators.AND, bMask) + .intoArray(b, i); + vector.lanewise(VectorOperators.AND, cMask).intoArray(c, cIndex + i); + + in.seek(endOffset); + } + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorConstants.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorConstants.java new file mode 100644 index 00000000000..e0c5bbca38e --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorConstants.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; +import org.apache.lucene.util.Constants; + +/** Shared constants for implementations that take advantage of the Panama Vector API. */ +final class PanamaVectorConstants { + + /** Preferred width in bits for vectors. */ + static final int PREFERRED_VECTOR_BITSIZE; + + /** Whether integer vectors can be trusted to actually be fast. */ + static final boolean HAS_FAST_INTEGER_VECTORS; + + static final VectorSpecies PRERERRED_LONG_SPECIES; + static final VectorSpecies PRERERRED_INT_SPECIES; + + static { + // default to platform supported bitsize + int vectorBitSize = VectorShape.preferredShape().vectorBitSize(); + // but allow easy overriding for testing + PREFERRED_VECTOR_BITSIZE = VectorizationProvider.TESTS_VECTOR_SIZE.orElse(vectorBitSize); + + // hotspot misses some SSE intrinsics, workaround it + // to be fair, they do document this thing only works well with AVX2/AVX3 and Neon + boolean isAMD64withoutAVX2 = + Constants.OS_ARCH.equals("amd64") && PREFERRED_VECTOR_BITSIZE < 256; + HAS_FAST_INTEGER_VECTORS = + VectorizationProvider.TESTS_FORCE_INTEGER_VECTORS || (isAMD64withoutAVX2 == false); + + PRERERRED_LONG_SPECIES = + VectorSpecies.of(long.class, VectorShape.forBitSize(PREFERRED_VECTOR_BITSIZE)); + PRERERRED_INT_SPECIES = + VectorSpecies.of(int.class, VectorShape.forBitSize(PREFERRED_VECTOR_BITSIZE)); + } + + private PanamaVectorConstants() {} +} diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 867d0c684cb..ad2dff11cea 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -52,20 +52,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // preferred vector sizes, which can be altered for testing private static final VectorSpecies FLOAT_SPECIES; - private static final VectorSpecies INT_SPECIES; + private static final VectorSpecies INT_SPECIES = + PanamaVectorConstants.PRERERRED_INT_SPECIES; private static final VectorSpecies BYTE_SPECIES; private static final VectorSpecies SHORT_SPECIES; static final int VECTOR_BITSIZE; - static final boolean HAS_FAST_INTEGER_VECTORS; static { - // default to platform supported bitsize - int vectorBitSize = VectorShape.preferredShape().vectorBitSize(); - // but allow easy overriding for testing - vectorBitSize = VectorizationProvider.TESTS_VECTOR_SIZE.orElse(vectorBitSize); - INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(vectorBitSize)); - VECTOR_BITSIZE = INT_SPECIES.vectorBitSize(); + VECTOR_BITSIZE = PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE; FLOAT_SPECIES = INT_SPECIES.withLanes(float.class); // compute BYTE/SHORT sizes relative to preferred integer vector size if (VECTOR_BITSIZE >= 256) { @@ -76,11 +71,6 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { BYTE_SPECIES = null; SHORT_SPECIES = null; } - // hotspot misses some SSE intrinsics, workaround it - // to be fair, they do document this thing only works well with AVX2/AVX3 and Neon - boolean isAMD64withoutAVX2 = Constants.OS_ARCH.equals("amd64") && VECTOR_BITSIZE < 256; - HAS_FAST_INTEGER_VECTORS = - VectorizationProvider.TESTS_FORCE_INTEGER_VECTORS || (isAMD64withoutAVX2 == false); } // the way FMA should work! if available use it, otherwise fall back to mul/add @@ -320,7 +310,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit // vectors (256-bit on intel to dodge performance landmines) - if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { // compute vectorized dot product consistent with VPDPBUSD instruction if (VECTOR_BITSIZE >= 512) { i += BYTE_SPECIES.loopBound(a.byteSize()); @@ -414,7 +404,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { } else if (VECTOR_BITSIZE == 256) { i += ByteVector.SPECIES_128.loopBound(packed.length); res += dotProductBody256Int4Packed(unpacked, packed, i); - } else if (HAS_FAST_INTEGER_VECTORS) { + } else if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { i += ByteVector.SPECIES_64.loopBound(packed.length); res += dotProductBody128Int4Packed(unpacked, packed, i); } @@ -430,7 +420,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { } else { if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { return dotProduct(a, b); - } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { + } else if (a.length >= 32 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { i += ByteVector.SPECIES_128.loopBound(a.length); res += int4DotProductBody128(a, b, i); } @@ -588,7 +578,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit // vectors (256-bit on intel to dodge performance landmines) - if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { final float[] ret; if (VECTOR_BITSIZE >= 512) { i += BYTE_SPECIES.loopBound((int) a.byteSize()); @@ -711,7 +701,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit // vectors (256-bit on intel to dodge performance landmines) - if (a.byteSize() >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) { if (VECTOR_BITSIZE >= 256) { i += BYTE_SPECIES.loopBound((int) a.byteSize()); res += squareDistanceBody256(a, b, i); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index 87f7cf2baf7..0e060586c2a 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -16,19 +16,25 @@ */ package org.apache.lucene.internal.vectorization; +import java.io.IOException; +import java.lang.foreign.MemorySegment; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Locale; import java.util.logging.Logger; import jdk.incubator.vector.FloatVector; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.Constants; import org.apache.lucene.util.SuppressForbidden; /** A vectorization provider that leverages the Panama Vector API. */ final class PanamaVectorizationProvider extends VectorizationProvider { - private final VectorUtilSupport vectorUtilSupport; + // NOTE: Avoid static fields or initializers which rely on the vector API, as these initializers + // would get called before we have a chance to perform sanity checks around the vector API in the + // constructor of this class. Put them in PanamaVectorConstants instead. // Extracted to a method to be able to apply the SuppressForbidden annotation @SuppressWarnings("removal") @@ -37,6 +43,8 @@ final class PanamaVectorizationProvider extends VectorizationProvider { return AccessController.doPrivileged(action); } + private final VectorUtilSupport vectorUtilSupport; + PanamaVectorizationProvider() { // hack to work around for JDK-8309727: try { @@ -51,9 +59,9 @@ final class PanamaVectorizationProvider extends VectorizationProvider { "We hit initialization failure described in JDK-8309727: " + se); } - if (PanamaVectorUtilSupport.VECTOR_BITSIZE < 128) { + if (PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE < 128) { throw new UnsupportedOperationException( - "Vector bit size is less than 128: " + PanamaVectorUtilSupport.VECTOR_BITSIZE); + "Vector bit size is less than 128: " + PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE); } this.vectorUtilSupport = new PanamaVectorUtilSupport(); @@ -63,11 +71,9 @@ final class PanamaVectorizationProvider extends VectorizationProvider { String.format( Locale.ENGLISH, "Java vector incubator API enabled; uses preferredBitSize=%d%s%s", - PanamaVectorUtilSupport.VECTOR_BITSIZE, + PanamaVectorConstants.PREFERRED_VECTOR_BITSIZE, Constants.HAS_FAST_VECTOR_FMA ? "; FMA enabled" : "", - PanamaVectorUtilSupport.HAS_FAST_INTEGER_VECTORS - ? "" - : "; floating-point vectors only")); + PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS ? "" : "; floating-point vectors only")); } @Override @@ -79,4 +85,16 @@ final class PanamaVectorizationProvider extends VectorizationProvider { public FlatVectorsScorer getLucene99FlatVectorsScorer() { return Lucene99MemorySegmentFlatVectorsScorer.INSTANCE; } + + @Override + public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException { + if (PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS + && input instanceof MemorySegmentAccessInput msai) { + MemorySegment ms = msai.segmentSliceOrNull(0, input.length()); + if (ms != null) { + return new MemorySegmentPostingDecodingUtil(input, ms); + } + } + return new DefaultPostingDecodingUtil(input); + } } diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java index 7c22eccdcf1..8b6452a748b 100644 --- a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java +++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentAccessInput.java @@ -27,7 +27,7 @@ import java.lang.foreign.MemorySegment; public interface MemorySegmentAccessInput extends RandomAccessInput, Cloneable { /** Returns the memory segment for a given position and length, or null. */ - MemorySegment segmentSliceOrNull(long pos, int len) throws IOException; + MemorySegment segmentSliceOrNull(long pos, long len) throws IOException; MemorySegmentAccessInput clone(); } diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java index e9805f0f7a6..c6ac3d23a12 100644 --- a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java +++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java @@ -742,7 +742,7 @@ abstract class MemorySegmentIndexInput extends IndexInput } @Override - public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException { + public MemorySegment segmentSliceOrNull(long pos, long len) throws IOException { try { Objects.checkIndex(pos + len, this.length + 1); return curSegment.asSlice(pos, len); @@ -816,7 +816,8 @@ abstract class MemorySegmentIndexInput extends IndexInput return super.readLong(pos + offset); } - public MemorySegment segmentSliceOrNull(long pos, int len) throws IOException { + @Override + public MemorySegment segmentSliceOrNull(long pos, long len) throws IOException { if (pos + len > length) { throw handlePositionalIOOBE(null, "segmentSliceOrNull", pos); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForDeltaUtil.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForDeltaUtil.java index 3c201ce6835..93ad6b3b6b2 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForDeltaUtil.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForDeltaUtil.java @@ -64,11 +64,13 @@ public class TestForDeltaUtil extends LuceneTestCase { { // decode IndexInput in = d.openInput("test.bin", IOContext.READONCE); - final ForDeltaUtil forDeltaUtil = new ForDeltaUtil(new ForUtil()); + ForUtil forUtil = new ForUtil(); + PostingIndexInput postingIn = new PostingIndexInput(in, forUtil); + final ForDeltaUtil forDeltaUtil = new ForDeltaUtil(forUtil); for (int i = 0; i < iterations; ++i) { long base = 0; final long[] restored = new long[ForUtil.BLOCK_SIZE]; - forDeltaUtil.decodeAndPrefixSum(in, base, restored); + forDeltaUtil.decodeAndPrefixSum(postingIn, base, restored); final long[] expected = new long[ForUtil.BLOCK_SIZE]; for (int j = 0; j < ForUtil.BLOCK_SIZE; ++j) { expected[j] = values[i * ForUtil.BLOCK_SIZE + j]; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForUtil.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForUtil.java index 114a9d0415c..dec60fc6762 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForUtil.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestForUtil.java @@ -69,12 +69,13 @@ public class TestForUtil extends LuceneTestCase { { // decode IndexInput in = d.openInput("test.bin", IOContext.READONCE); - final ForUtil forUtil = new ForUtil(); + ForUtil forUtil = new ForUtil(); + PostingIndexInput postingIn = new PostingIndexInput(in, forUtil); for (int i = 0; i < iterations; ++i) { final int bitsPerValue = in.readByte(); final long currentFilePointer = in.getFilePointer(); final long[] restored = new long[ForUtil.BLOCK_SIZE]; - forUtil.decode(bitsPerValue, in, restored); + postingIn.decode(bitsPerValue, restored); int[] ints = new int[ForUtil.BLOCK_SIZE]; for (int j = 0; j < ForUtil.BLOCK_SIZE; ++j) { ints[j] = Math.toIntExact(restored[j]); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestPForUtil.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestPForUtil.java index 08fec7a3a33..d185aff7645 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestPForUtil.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene912/TestPForUtil.java @@ -38,15 +38,17 @@ public class TestPForUtil extends LuceneTestCase { final Directory d = new ByteBuffersDirectory(); final long endPointer = encodeTestData(iterations, values, d); + ForUtil forUtil = new ForUtil(); IndexInput in = d.openInput("test.bin", IOContext.READONCE); - final PForUtil pforUtil = new PForUtil(new ForUtil()); + PostingIndexInput postingIn = new PostingIndexInput(in, forUtil); + final PForUtil pforUtil = new PForUtil(forUtil); for (int i = 0; i < iterations; ++i) { if (random().nextInt(5) == 0) { pforUtil.skip(in); continue; } final long[] restored = new long[ForUtil.BLOCK_SIZE]; - pforUtil.decode(in, restored); + pforUtil.decode(postingIn, restored); int[] ints = new int[ForUtil.BLOCK_SIZE]; for (int j = 0; j < ForUtil.BLOCK_SIZE; ++j) { ints[j] = Math.toIntExact(restored[j]); diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestPostingDecodingUtil.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestPostingDecodingUtil.java new file mode 100644 index 00000000000..64d1c23930e --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestPostingDecodingUtil.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import org.apache.lucene.codecs.lucene912.ForUtil; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; + +public class TestPostingDecodingUtil extends LuceneTestCase { + + public void testDuelSplitLongs() throws Exception { + final int iterations = atLeast(100); + + try (Directory dir = new MMapDirectory(createTempDir())) { + try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) { + out.writeInt(random().nextInt()); + for (int i = 0; i < ForUtil.BLOCK_SIZE; ++i) { + out.writeLong(random().nextInt()); + } + } + VectorizationProvider vectorizationProvider = VectorizationProvider.lookup(true); + try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) { + long[] expectedB = new long[ForUtil.BLOCK_SIZE]; + long[] expectedC = new long[ForUtil.BLOCK_SIZE]; + long[] actualB = new long[ForUtil.BLOCK_SIZE]; + long[] actualC = new long[ForUtil.BLOCK_SIZE]; + for (int iter = 0; iter < iterations; ++iter) { + // Initialize arrays with random content. + for (int i = 0; i < expectedB.length; ++i) { + expectedB[i] = random().nextLong(); + actualB[i] = expectedB[i]; + expectedC[i] = random().nextLong(); + actualC[i] = expectedC[i]; + } + int count = TestUtil.nextInt(random(), 1, 64); + int bShift = TestUtil.nextInt(random(), 1, 31); + long bMask = random().nextLong(); + int cIndex = random().nextInt(64); + long cMask = random().nextLong(); + long startFP = random().nextInt(4); + + // Work on a slice that has just the right number of bytes to make the test fail with an + // index-out-of-bounds in case the implementation reads more than the allowed number of + // padding bytes. + IndexInput slice = in.slice("test", 0, startFP + count * Long.BYTES); + + PostingDecodingUtil defaultUtil = new DefaultPostingDecodingUtil(slice); + PostingDecodingUtil optimizedUtil = vectorizationProvider.newPostingDecodingUtil(slice); + + slice.seek(startFP); + defaultUtil.splitLongs(count, expectedB, bShift, bMask, expectedC, cIndex, cMask); + long expectedEndFP = slice.getFilePointer(); + slice.seek(startFP); + optimizedUtil.splitLongs(count, actualB, bShift, bMask, actualC, cIndex, cMask); + assertEquals(expectedEndFP, slice.getFilePointer()); + assertArrayEquals(expectedB, actualB); + assertArrayEquals(expectedC, actualC); + } + } + } + } +}