From 6bba35a7096ef2d4023ba9339e9cc30268d520ce Mon Sep 17 00:00:00 2001 From: Bruno Roustant Date: Thu, 9 Apr 2020 10:36:37 +0200 Subject: [PATCH] LUCENE-9286: FST.Arc.BitTable reads directly FST bytes. Arc is lightweight again and FSTEnum traversal faster. --- lucene/CHANGES.txt | 3 + .../java/org/apache/lucene/util/BitUtil.java | 128 +----- .../apache/lucene/util/fst/BitTableUtil.java | 179 +++++++++ .../java/org/apache/lucene/util/fst/FST.java | 367 ++++++------------ .../org/apache/lucene/util/fst/FSTEnum.java | 17 +- .../org/apache/lucene/util/fst/NodeHash.java | 2 +- .../java/org/apache/lucene/util/fst/Util.java | 19 +- .../org/apache/lucene/util/TestBitUtil.java | 87 ----- .../lucene/util/fst/TestBitTableUtil.java | 138 +++++++ .../util/fst/TestFSTDirectAddressing.java | 85 +++- 10 files changed, 548 insertions(+), 477 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/util/fst/BitTableUtil.java delete mode 100644 lucene/core/src/test/org/apache/lucene/util/TestBitUtil.java create mode 100644 lucene/core/src/test/org/apache/lucene/util/fst/TestBitTableUtil.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 38c262fd285..363d823fe65 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -144,6 +144,9 @@ Optimizations * LUCENE-9287: UsageTrackingQueryCachingPolicy no longer caches DocValuesFieldExistsQuery. (Ignacio Vera) +* LUCENE-9286: FST.Arc.BitTable reads directly FST bytes. Arc is lightweight again and FSTEnum traversal faster. + (Bruno Roustant) + Bug Fixes --------------------- * LUCENE-9259: Fix wrong NGramFilterFactory argument name for preserveOriginal option (Paul Pazderski) diff --git a/lucene/core/src/java/org/apache/lucene/util/BitUtil.java b/lucene/core/src/java/org/apache/lucene/util/BitUtil.java index 65f69b06c92..e308a9aaebe 100644 --- a/lucene/core/src/java/org/apache/lucene/util/BitUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/BitUtil.java @@ -155,7 +155,7 @@ public final class BitUtil { /** * flip flops odd with even bits */ - public static final long flipFlop(final long b) { + public static long flipFlop(final long b) { return ((b & MAGIC6) >>> 1) | ((b & MAGIC0) << 1 ); } @@ -183,130 +183,4 @@ public final class BitUtil { public static long zigZagDecode(long l) { return ((l >>> 1) ^ -(l & 1)); } - - /** - * Returns whether the bit at given zero-based index is set. - *
Example: bitIndex 66 means the third bit on the right of the second long. - * - * @param bits The bits stored in an array of long for efficiency. - * @param numLongs The number of longs in {@code bits} to consider. - * @param bitIndex The bit zero-based index. It must be greater than or equal to 0, - * and strictly less than {@code numLongs * Long.SIZE}. - */ - public static boolean isBitSet(long[] bits, int numLongs, int bitIndex) { - assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= 0 && bitIndex < numLongs * Long.SIZE - : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length; - return (bits[bitIndex / Long.SIZE] & (1L << bitIndex)) != 0; // Shifts are mod 64. - } - - /** - * Counts all bits set in the provided longs. - * - * @param bits The bits stored in an array of long for efficiency. - * @param numLongs The number of longs in {@code bits} to consider. - */ - public static int countBits(long[] bits, int numLongs) { - assert numLongs >= 0 && numLongs <= bits.length - : "numLongs=" + numLongs + " bits.length=" + bits.length; - int bitCount = 0; - for (int i = 0; i < numLongs; i++) { - bitCount += Long.bitCount(bits[i]); - } - return bitCount; - } - - /** - * Counts the bits set up to the given bit zero-based index, exclusive. - *
In other words, how many 1s there are up to the bit at the given index excluded. - *
Example: bitIndex 66 means the third bit on the right of the second long. - * - * @param bits The bits stored in an array of long for efficiency. - * @param numLongs The number of longs in {@code bits} to consider. - * @param bitIndex The bit zero-based index, exclusive. It must be greater than or equal to 0, - * and less than or equal to {@code numLongs * Long.SIZE}. - */ - public static int countBitsUpTo(long[] bits, int numLongs, int bitIndex) { - assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= 0 && bitIndex <= numLongs * Long.SIZE - : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length; - int bitCount = 0; - int lastLong = bitIndex / Long.SIZE; - for (int i = 0; i < lastLong; i++) { - // Count the bits set for all plain longs. - bitCount += Long.bitCount(bits[i]); - } - if (lastLong < numLongs) { - // Prepare a mask with 1s on the right up to bitIndex exclusive. - long mask = (1L << bitIndex) - 1L; // Shifts are mod 64. - // Count the bits set only within the mask part, so up to bitIndex exclusive. - bitCount += Long.bitCount(bits[lastLong] & mask); - } - return bitCount; - } - - /** - * Returns the index of the next bit set following the given bit zero-based index. - *
For example with bits 100011: - * the next bit set after index=-1 is at index=0; - * the next bit set after index=0 is at index=1; - * the next bit set after index=1 is at index=5; - * there is no next bit set after index=5. - * - * @param bits The bits stored in an array of long for efficiency. - * @param numLongs The number of longs in {@code bits} to consider. - * @param bitIndex The bit zero-based index. It must be greater than or equal to -1, - * and strictly less than {@code numLongs * Long.SIZE}. - * @return The zero-based index of the next bit set after the provided {@code bitIndex}; - * or -1 if none. - */ - public static int nextBitSet(long[] bits, int numLongs, int bitIndex) { - assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= -1 && bitIndex < numLongs * Long.SIZE - : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length; - int longIndex = bitIndex / Long.SIZE; - // Prepare a mask with 1s on the left down to bitIndex exclusive. - long mask = -(1L << (bitIndex + 1)); // Shifts are mod 64. - long l = mask == -1 && bitIndex != -1 ? 0 : bits[longIndex] & mask; - while (l == 0) { - if (++longIndex == numLongs) { - return -1; - } - l = bits[longIndex]; - } - return Long.numberOfTrailingZeros(l) + longIndex * 64; - } - - /** - * Returns the index of the previous bit set preceding the given bit zero-based index. - *
For example with bits 100011: - * there is no previous bit set before index=0. - * the previous bit set before index=1 is at index=0; - * the previous bit set before index=5 is at index=1; - * the previous bit set before index=64 is at index=5; - * - * @param bits The bits stored in an array of long for efficiency. - * @param numLongs The number of longs in {@code bits} to consider. - * @param bitIndex The bit zero-based index. It must be greater than or equal to 0, - * and less than or equal to {@code numLongs * Long.SIZE}. - * @return The zero-based index of the previous bit set before the provided {@code bitIndex}; - * or -1 if none. - */ - public static int previousBitSet(long[] bits, int numLongs, int bitIndex) { - assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= 0 && bitIndex <= numLongs * Long.SIZE - : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length; - int longIndex = bitIndex / Long.SIZE; - long l; - if (longIndex == numLongs) { - l = 0; - } else { - // Prepare a mask with 1s on the right up to bitIndex exclusive. - long mask = (1L << bitIndex) - 1L; // Shifts are mod 64. - l = bits[longIndex] & mask; - } - while (l == 0) { - if (longIndex-- == 0) { - return -1; - } - l = bits[longIndex]; - } - return 63 - Long.numberOfLeadingZeros(l) + longIndex * 64; - } } diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/BitTableUtil.java b/lucene/core/src/java/org/apache/lucene/util/fst/BitTableUtil.java new file mode 100644 index 00000000000..ee59cd4dfbf --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/fst/BitTableUtil.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.fst; + +import java.io.IOException; + +/** + * Static helper methods for {@link FST.Arc.BitTable}. + * + * @lucene.experimental + */ +class BitTableUtil { + + /** + * Returns whether the bit at given zero-based index is set. + *
Example: bitIndex 10 means the third bit on the right of the second byte. + * + * @param bitIndex The bit zero-based index. It must be greater than or equal to 0, and strictly less than + * {@code number of bit-table bytes * Byte.SIZE}. + * @param reader The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table. + */ + static boolean isBitSet(int bitIndex, FST.BytesReader reader) throws IOException { + assert bitIndex >= 0 : "bitIndex=" + bitIndex; + reader.skipBytes(bitIndex >> 3); + return (readByte(reader) & (1L << (bitIndex & (Byte.SIZE - 1)))) != 0; + } + + + /** + * Counts all bits set in the bit-table. + * + * @param bitTableBytes The number of bytes in the bit-table. + * @param reader The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table. + */ + static int countBits(int bitTableBytes, FST.BytesReader reader) throws IOException { + assert bitTableBytes >= 0 : "bitTableBytes=" + bitTableBytes; + int bitCount = 0; + for (int i = bitTableBytes >> 3; i > 0; i--) { + // Count the bits set for all plain longs. + bitCount += Long.bitCount(read8Bytes(reader)); + } + int numRemainingBytes; + if ((numRemainingBytes = bitTableBytes & (Long.BYTES - 1)) != 0) { + bitCount += Long.bitCount(readUpTo8Bytes(numRemainingBytes, reader)); + } + return bitCount; + } + + /** + * Counts the bits set up to the given bit zero-based index, exclusive. + *
In other words, how many 1s there are up to the bit at the given index excluded. + *
Example: bitIndex 10 means the third bit on the right of the second byte. + * + * @param bitIndex The bit zero-based index, exclusive. It must be greater than or equal to 0, and less than or equal + * to {@code number of bit-table bytes * Byte.SIZE}. + * @param reader The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table. + */ + static int countBitsUpTo(int bitIndex, FST.BytesReader reader) throws IOException { + assert bitIndex >= 0 : "bitIndex=" + bitIndex; + int bitCount = 0; + for (int i = bitIndex >> 6; i > 0; i--) { + // Count the bits set for all plain longs. + bitCount += Long.bitCount(read8Bytes(reader)); + } + int remainingBits; + if ((remainingBits = bitIndex & (Long.SIZE - 1)) != 0) { + int numRemainingBytes = (remainingBits + (Byte.SIZE - 1)) >> 3; + // Prepare a mask with 1s on the right up to bitIndex exclusive. + long mask = (1L << bitIndex) - 1L; // Shifts are mod 64. + // Count the bits set only within the mask part, so up to bitIndex exclusive. + bitCount += Long.bitCount(readUpTo8Bytes(numRemainingBytes, reader) & mask); + } + return bitCount; + } + + /** + * Returns the index of the next bit set following the given bit zero-based index. + *
For example with bits 100011: + * the next bit set after index=-1 is at index=0; + * the next bit set after index=0 is at index=1; + * the next bit set after index=1 is at index=5; + * there is no next bit set after index=5. + * + * @param bitIndex The bit zero-based index. It must be greater than or equal to -1, and strictly less than + * {@code number of bit-table bytes * Byte.SIZE}. + * @param bitTableBytes The number of bytes in the bit-table. + * @param reader The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table. + * @return The zero-based index of the next bit set after the provided {@code bitIndex}; or -1 if none. + */ + static int nextBitSet(int bitIndex, int bitTableBytes, FST.BytesReader reader) throws IOException { + assert bitIndex >= -1 && bitIndex < bitTableBytes * Byte.SIZE : "bitIndex=" + bitIndex + " bitTableBytes=" + bitTableBytes; + int byteIndex = bitIndex / Byte.SIZE; + int mask = -1 << ((bitIndex + 1) & (Byte.SIZE - 1)); + int i; + if (mask == -1 && bitIndex != -1) { + reader.skipBytes(byteIndex + 1); + i = 0; + } else { + reader.skipBytes(byteIndex); + i = (reader.readByte() & 0xFF) & mask; + } + while (i == 0) { + if (++byteIndex == bitTableBytes) { + return -1; + } + i = reader.readByte() & 0xFF; + } + return Integer.numberOfTrailingZeros(i) + (byteIndex << 3); + } + + /** + * Returns the index of the previous bit set preceding the given bit zero-based index. + *
For example with bits 100011: + * there is no previous bit set before index=0. + * the previous bit set before index=1 is at index=0; + * the previous bit set before index=5 is at index=1; + * the previous bit set before index=64 is at index=5; + * + * @param bitIndex The bit zero-based index. It must be greater than or equal to 0, and less than or equal to + * {@code number of bit-table bytes * Byte.SIZE}. + * @param reader The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table. + * @return The zero-based index of the previous bit set before the provided {@code bitIndex}; or -1 if none. + */ + static int previousBitSet(int bitIndex, FST.BytesReader reader) throws IOException { + assert bitIndex >= 0 : "bitIndex=" + bitIndex; + int byteIndex = bitIndex >> 3; + reader.skipBytes(byteIndex); + int mask = (1 << (bitIndex & (Byte.SIZE - 1))) - 1; + int i = (reader.readByte() & 0xFF) & mask; + while (i == 0) { + if (byteIndex-- == 0) { + return -1; + } + reader.skipBytes(-2); // FST.BytesReader implementations support negative skip. + i = reader.readByte() & 0xFF; + } + return (Integer.SIZE - 1) - Integer.numberOfLeadingZeros(i) + (byteIndex << 3); + } + + private static long readByte(FST.BytesReader reader) throws IOException { + return reader.readByte() & 0xFFL; + } + + private static long readUpTo8Bytes(int numBytes, FST.BytesReader reader) throws IOException { + assert numBytes > 0 && numBytes <= 8 : "numBytes=" + numBytes; + long l = readByte(reader); + int shift = 0; + while (--numBytes != 0) { + l |= readByte(reader) << (shift += 8); + } + return l; + } + + private static long read8Bytes(FST.BytesReader reader) throws IOException { + return readByte(reader) + | readByte(reader) << 8 + | readByte(reader) << 16 + | readByte(reader) << 24 + | readByte(reader) << 32 + | readByte(reader) << 40 + | readByte(reader) << 48 + | readByte(reader) << 56; + } +} \ No newline at end of file diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FST.java b/lucene/core/src/java/org/apache/lucene/util/fst/FST.java index 571c1e5e23f..dca727572d2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/FST.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/FST.java @@ -33,10 +33,11 @@ import org.apache.lucene.store.InputStreamDataInput; import org.apache.lucene.store.OutputStreamDataOutput; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.RamUsageEstimator; +import static org.apache.lucene.util.fst.FST.Arc.BitTable; + // TODO: break this into WritableFST and ReadOnlyFST.. then // we can have subclasses of ReadOnlyFST to handle the // different byte[] level encodings (packed or @@ -164,23 +165,36 @@ public final class FST implements Accountable { private long nextArc; - private int arcIdx; + private byte nodeFlags; //*** Fields for arcs belonging to a node with fixed length arcs. // So only valid when bytesPerArc != 0. - - private byte nodeFlags; - - private long posArcsStart; + // nodeFlags == ARCS_FOR_BINARY_SEARCH || nodeFlags == ARCS_FOR_DIRECT_ADDRESSING. private int bytesPerArc; + private long posArcsStart; + + private int arcIdx; + private int numArcs; - private BitTable bitTable; + //*** Fields for a direct addressing node. nodeFlags == ARCS_FOR_DIRECT_ADDRESSING. + /** Start position in the {@link FST.BytesReader} of the presence bits for a direct addressing node, aka the bit-table */ + private long bitTableStart; + + /** First label of a direct addressing node. */ private int firstLabel; + /** + * Index of the current label of a direct addressing node. While {@link #arcIdx} is the current index in the label + * range, {@link #presenceIndex} is its corresponding index in the list of actually present labels. It is equal + * to the number of bits set before the bit at {@link #arcIdx} in the bit-table. This field is a cache to avoid + * to count bits set repeatedly when iterating the next arcs. + */ + private int presenceIndex; + /** Returns this */ public Arc copyFrom(Arc other) { label = other.label(); @@ -191,15 +205,18 @@ public final class FST implements Accountable { nextArc = other.nextArc(); nodeFlags = other.nodeFlags(); bytesPerArc = other.bytesPerArc(); - if (bytesPerArc() != 0) { - posArcsStart = other.posArcsStart(); - arcIdx = other.arcIdx(); - numArcs = other.numArcs(); - if (nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING) { - bitTable = other.bitTable() == null ? null : other.bitTable().copy(); - firstLabel = other.firstLabel(); - } - } + + // Fields for arcs belonging to a node with fixed length arcs. + // We could avoid copying them if bytesPerArc() == 0 (this was the case with previous code, and the current code + // still supports that), but it may actually help external uses of FST to have consistent arc state, and debugging + // is easier. + posArcsStart = other.posArcsStart(); + arcIdx = other.arcIdx(); + numArcs = other.numArcs(); + bitTableStart = other.bitTableStart; + firstLabel = other.firstLabel(); + presenceIndex = other.presenceIndex; + return this; } @@ -239,7 +256,8 @@ public final class FST implements Accountable { b.append(" nextFinalOutput=").append(nextFinalOutput()); } if (bytesPerArc() != 0) { - b.append(" arcArray(idx=").append(arcIdx()).append(" of ").append(numArcs()).append(")"); + b.append(" arcArray(idx=").append(arcIdx()).append(" of ").append(numArcs()).append(")") + .append("(").append(nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING ? "da" : "bs").append(")"); } return b.toString(); } @@ -303,21 +321,6 @@ public final class FST implements Accountable { return numArcs; } - /** Table of bits of a direct addressing node. - * Only valid if nodeFlags == {@link #ARCS_FOR_DIRECT_ADDRESSING}; - * may be null otherwise. */ - BitTable bitTable() { - return bitTable; - } - - /** The table of bits of a direct addressing node created lazily. */ - BitTable getOrCreateBitTable() { - if (bitTable == null) { - bitTable = new BitTable(); - } - return bitTable; - } - /** First label of a direct addressing node. * Only valid if nodeFlags == {@link #ARCS_FOR_DIRECT_ADDRESSING}. */ int firstLabel() { @@ -325,65 +328,63 @@ public final class FST implements Accountable { } /** - * Reusable table of bits using an array of long internally. + * Helper methods to read the bit-table of a direct addressing node. + * Only valid for {@link Arc} with {@link Arc#nodeFlags()} == {@code ARCS_FOR_DIRECT_ADDRESSING}. */ static class BitTable { - private long[] bits; - private int numLongs; - - /** Sets the number of longs in the internal long array. - * Enlarges it if needed. Always clears the array. */ - BitTable setNumLongs(int numLongs) { - assert numLongs >= 0; - this.numLongs = numLongs; - if (bits == null || bits.length < numLongs) { - bits = new long[ArrayUtil.oversize(numLongs, Long.BYTES)]; - } else { - for (int i = 0; i < numLongs; i++) { - bits[i] = 0L; - } - } - return this; + /** See {@link BitTableUtil#isBitSet(int, FST.BytesReader)}. */ + static boolean isBitSet(int bitIndex, Arc arc, FST.BytesReader in) throws IOException { + assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; + in.setPosition(arc.bitTableStart); + return BitTableUtil.isBitSet(bitIndex, in); } - /** Creates a new {@link BitTable} by copying this one. */ - BitTable copy() { - BitTable bitTable = new BitTable(); - bitTable.bits = ArrayUtil.copyOfSubArray(bits, 0, bits.length); - bitTable.numLongs = numLongs; - return bitTable; + /** + * See {@link BitTableUtil#countBits(int, FST.BytesReader)}. + * The count of bit set is the number of arcs of a direct addressing node. + */ + static int countBits(Arc arc, FST.BytesReader in) throws IOException { + assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; + in.setPosition(arc.bitTableStart); + return BitTableUtil.countBits(getNumPresenceBytes(arc.numArcs()), in); } - boolean assertIsValid() { - assert numLongs > 0 && numLongs <= bits.length; + /** See {@link BitTableUtil#countBitsUpTo(int, FST.BytesReader)}. */ + static int countBitsUpTo(int bitIndex, Arc arc, FST.BytesReader in) throws IOException { + assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; + in.setPosition(arc.bitTableStart); + return BitTableUtil.countBitsUpTo(bitIndex, in); + } + + /** See {@link BitTableUtil#nextBitSet(int, int, FST.BytesReader)}. */ + static int nextBitSet(int bitIndex, Arc arc, FST.BytesReader in) throws IOException { + assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; + in.setPosition(arc.bitTableStart); + return BitTableUtil.nextBitSet(bitIndex, getNumPresenceBytes(arc.numArcs()), in); + } + + /** See {@link BitTableUtil#previousBitSet(int, FST.BytesReader)}. */ + static int previousBitSet(int bitIndex, Arc arc, FST.BytesReader in) throws IOException { + assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; + in.setPosition(arc.bitTableStart); + return BitTableUtil.previousBitSet(bitIndex, in); + } + + /** + * Asserts the bit-table of the provided {@link Arc} is valid. + */ + static boolean assertIsValid(Arc arc, FST.BytesReader in) throws IOException { + assert arc.bytesPerArc() > 0; + assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; + // First bit must be set. + assert isBitSet(0, arc, in); + // Last bit must be set. + assert isBitSet(arc.numArcs() - 1, arc, in); + // No bit set after the last arc. + assert nextBitSet(arc.numArcs() - 1, arc, in) == -1; return true; } - - /** Forwards to {@link BitUtil#isBitSet(long[], int, int)}. */ - boolean isBitSet(int bitIndex) { - return BitUtil.isBitSet(bits, numLongs, bitIndex); - } - - /** Forwards to {@link BitUtil#countBits(long[], int)}. */ - int countBits() { - return BitUtil.countBits(bits, numLongs); - } - - /** Forwards to {@link BitUtil#countBitsUpTo(long[], int, int)}. */ - int countBitsUpTo(int bitIndex) { - return BitUtil.countBitsUpTo(bits, numLongs, bitIndex); - } - - /** Forwards to {@link BitUtil#nextBitSet(long[], int, int)}. */ - int nextBitSet(int bitIndex) { - return BitUtil.nextBitSet(bits, numLongs, bitIndex); - } - - /** Forwards to {@link BitUtil#previousBitSet(long[], int, int)}. */ - int previousBitSet(int bitIndex) { - return BitUtil.previousBitSet(bits, numLongs, bitIndex); - } } } @@ -921,41 +922,19 @@ public final class FST implements Accountable { /** Gets the number of bytes required to flag the presence of each arc in the given label range, one bit per arc. */ private static int getNumPresenceBytes(int labelRange) { - return (labelRange + 7) / Byte.SIZE; + assert labelRange >= 0; + return (labelRange + 7) >> 3; } /** - * Reads the presence bits of a direct-addressing node, store them in the provided arc {@link Arc#bitTable()} - * and returns the number of presence bytes. + * Reads the presence bits of a direct-addressing node. + * Actually we don't read them here, we just keep the pointer to the bit-table start and we skip them. */ - private int readPresenceBytes(Arc arc, BytesReader in) throws IOException { - int numPresenceBytes = getNumPresenceBytes(arc.numArcs()); - Arc.BitTable presenceBits = arc.getOrCreateBitTable().setNumLongs((numPresenceBytes + 7) / Long.BYTES); - for (int i = 0; i < numPresenceBytes; i++) { - // Read the next unsigned byte, shift it to the left, and appends it to the current long. - presenceBits.bits[i / Long.BYTES] |= (in.readByte() & 0xFFL) << (i * Byte.SIZE); - } - assert assertPresenceBytesAreValid(arc); - return numPresenceBytes; - } - - private int getNumArcsDirectAddressing(Arc arc) { + private void readPresenceBytes(Arc arc, BytesReader in) throws IOException { + assert arc.bytesPerArc() > 0; assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; - return arc.bitTable().countBits(); - } - - private boolean assertPresenceBytesAreValid(Arc arc) { - assert arc.bitTable() != null; - assert arc.bitTable().assertIsValid(); - // First bit must be set. - assert arc.bitTable().isBitSet(0); - // Last bit must be set. - assert arc.bitTable().isBitSet(arc.numArcs() - 1); - // No bit set after the last arc. - assert arc.bitTable().nextBitSet(arc.numArcs() - 1) == -1; - // Total bit set (real num arcs) must be <= label range (stored in arc.numArcs()). - assert getNumArcsDirectAddressing(arc) <= arc.numArcs(); - return true; + arc.bitTableStart = in.getPosition(); + in.skipBytes(getNumPresenceBytes(arc.numArcs())); } /** Fills virtual 'start' arc, ie, an empty incoming arc to the FST's start node */ @@ -1010,7 +989,7 @@ public final class FST implements Accountable { readPresenceBytes(arc, in); arc.firstLabel = readLabel(in); arc.posArcsStart = in.getPosition(); - readArcByDirectAddressing(arc, in, arc.numArcs() - 1); + readLastArcByDirectAddressing(arc, in); } else { arc.arcIdx = arc.numArcs() - 2; arc.posArcsStart = in.getPosition(); @@ -1095,6 +1074,7 @@ public final class FST implements Accountable { if (flags == ARCS_FOR_DIRECT_ADDRESSING) { readPresenceBytes(arc, in); arc.firstLabel = readLabel(in); + arc.presenceIndex = -1; } arc.posArcsStart = in.getPosition(); //System.out.println(" bytesPer=" + arc.bytesPerArc + " numArcs=" + arc.numArcs + " arcsStart=" + pos); @@ -1166,9 +1146,9 @@ public final class FST implements Accountable { assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; // Direct addressing node. The label is not stored but rather inferred // based on first label and arc index in the range. - assert assertPresenceBytesAreValid(arc); - assert arc.bitTable().isBitSet(arc.arcIdx()); - int nextIndex = arc.bitTable().nextBitSet(arc.arcIdx()); + assert BitTable.assertIsValid(arc, in); + assert BitTable.isBitSet(arc.arcIdx(), arc, in); + int nextIndex = BitTable.nextBitSet(arc.arcIdx(), arc, in); assert nextIndex != -1; return arc.firstLabel() + nextIndex; } @@ -1183,6 +1163,8 @@ public final class FST implements Accountable { } public Arc readArcByIndex(Arc arc, final BytesReader in, int idx) throws IOException { + assert arc.bytesPerArc() > 0; + assert arc.nodeFlags() == ARCS_FOR_BINARY_SEARCH; assert idx >= 0 && idx < arc.numArcs(); in.setPosition(arc.posArcsStart() - idx * arc.bytesPerArc()); arc.arcIdx = idx; @@ -1190,25 +1172,44 @@ public final class FST implements Accountable { return readArc(arc, in); } - /** Reads a present direct addressing node arc, with the provided index in the label range. + /** + * Reads a present direct addressing node arc, with the provided index in the label range. * * @param rangeIndex The index of the arc in the label range. It must be present. * The real arc offset is computed based on the presence bits of * the direct addressing node. */ public Arc readArcByDirectAddressing(Arc arc, final BytesReader in, int rangeIndex) throws IOException { - assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING; - assert arc.bytesPerArc() > 0; - assert assertPresenceBytesAreValid(arc); + assert BitTable.assertIsValid(arc, in); assert rangeIndex >= 0 && rangeIndex < arc.numArcs(); - assert arc.bitTable().isBitSet(rangeIndex); - int presenceIndex = arc.bitTable().countBitsUpTo(rangeIndex); + assert BitTable.isBitSet(rangeIndex, arc, in); + int presenceIndex = BitTable.countBitsUpTo(rangeIndex, arc, in); + return readArcByDirectAddressing(arc, in, rangeIndex, presenceIndex); + } + + /** + * Reads a present direct addressing node arc, with the provided index in the label range and its corresponding + * presence index (which is the count of presence bits before it). + */ + private Arc readArcByDirectAddressing(Arc arc, final BytesReader in, int rangeIndex, int presenceIndex) throws IOException { in.setPosition(arc.posArcsStart() - presenceIndex * arc.bytesPerArc()); arc.arcIdx = rangeIndex; + arc.presenceIndex = presenceIndex; arc.flags = in.readByte(); return readArc(arc, in); } + /** + * Reads the last arc of a direct addressing node. + * This method is equivalent to call {@link #readArcByDirectAddressing(Arc, BytesReader, int)} with {@code rangeIndex} + * equal to {@code arc.numArcs() - 1}, but it is faster. + */ + public Arc readLastArcByDirectAddressing(Arc arc, final BytesReader in) throws IOException { + assert BitTable.assertIsValid(arc, in); + int presenceIndex = BitTable.countBits(arc, in) - 1; + return readArcByDirectAddressing(arc, in, arc.numArcs() - 1, presenceIndex); + } + /** Never returns null, but you should never call this if * arc.isLast() is true. */ public Arc readNextRealArc(Arc arc, final BytesReader in) throws IOException { @@ -1227,11 +1228,10 @@ public final class FST implements Accountable { break; case ARCS_FOR_DIRECT_ADDRESSING: - assert arc.bytesPerArc() > 0; - assert assertPresenceBytesAreValid(arc); - assert arc.arcIdx() == -1 || arc.bitTable().isBitSet(arc.arcIdx()); - int nextIndex = arc.bitTable().nextBitSet(arc.arcIdx()); - return readArcByDirectAddressing(arc, in, nextIndex); + assert BitTable.assertIsValid(arc, in); + assert arc.arcIdx() == -1 || BitTable.isBitSet(arc.arcIdx(), arc, in); + int nextIndex = BitTable.nextBitSet(arc.arcIdx(), arc, in); + return readArcByDirectAddressing(arc, in, nextIndex, arc.presenceIndex + 1); default: // Variable length arcs - linear search. @@ -1282,7 +1282,7 @@ public final class FST implements Accountable { // must scan seekToNextNode(in); } else { - int numArcs = arc.nodeFlags == ARCS_FOR_DIRECT_ADDRESSING ? getNumArcsDirectAddressing(arc) : arc.numArcs(); + int numArcs = arc.nodeFlags == ARCS_FOR_DIRECT_ADDRESSING ? BitTable.countBits(arc, in) : arc.numArcs(); in.setPosition(arc.posArcsStart() - arc.bytesPerArc() * numArcs); } } @@ -1355,7 +1355,7 @@ public final class FST implements Accountable { int arcIndex = labelToMatch - arc.firstLabel(); if (arcIndex < 0 || arcIndex >= arc.numArcs()) { return null; // Before or after label range. - } else if (!arc.bitTable().isBitSet(arcIndex)) { + } else if (!BitTable.isBitSet(arcIndex, arc, in)) { return null; // Arc missing in the range. } return readArcByDirectAddressing(arc, in, arcIndex); @@ -1455,113 +1455,4 @@ public final class FST implements Accountable { * under-the-hood. */ public abstract boolean reversed(); } - - /* - public void countSingleChains() throws IOException { - // TODO: must assert this FST was built with - // "willRewrite" - - final List> queue = new ArrayList<>(); - - // TODO: use bitset to not revisit nodes already - // visited - - FixedBitSet seen = new FixedBitSet(1+nodeCount); - int saved = 0; - - queue.add(new ArcAndState(getFirstArc(new Arc()), new IntsRef())); - Arc scratchArc = new Arc<>(); - while(queue.size() > 0) { - //System.out.println("cycle size=" + queue.size()); - //for(ArcAndState ent : queue) { - // System.out.println(" " + Util.toBytesRef(ent.chain, new BytesRef())); - // } - final ArcAndState arcAndState = queue.get(queue.size()-1); - seen.set(arcAndState.arc.node); - final BytesRef br = Util.toBytesRef(arcAndState.chain, new BytesRef()); - if (br.length > 0 && br.bytes[br.length-1] == -1) { - br.length--; - } - //System.out.println(" top node=" + arcAndState.arc.target + " chain=" + br.utf8ToString()); - if (targetHasArcs(arcAndState.arc) && !seen.get(arcAndState.arc.target)) { - // push - readFirstTargetArc(arcAndState.arc, scratchArc); - //System.out.println(" push label=" + (char) scratchArc.label); - //System.out.println(" tonode=" + scratchArc.target + " last?=" + scratchArc.isLast()); - - final IntsRef chain = IntsRef.deepCopyOf(arcAndState.chain); - chain.grow(1+chain.length); - // TODO - //assert scratchArc.label != END_LABEL; - chain.ints[chain.length] = scratchArc.label; - chain.length++; - - if (scratchArc.isLast()) { - if (scratchArc.target != -1 && inCounts[scratchArc.target] == 1) { - //System.out.println(" append"); - } else { - if (arcAndState.chain.length > 1) { - saved += chain.length-2; - try { - System.out.println("chain: " + Util.toBytesRef(chain, new BytesRef()).utf8ToString()); - } catch (AssertionError ae) { - System.out.println("chain: " + Util.toBytesRef(chain, new BytesRef())); - } - } - chain.length = 0; - } - } else { - //System.out.println(" reset"); - if (arcAndState.chain.length > 1) { - saved += arcAndState.chain.length-2; - try { - System.out.println("chain: " + Util.toBytesRef(arcAndState.chain, new BytesRef()).utf8ToString()); - } catch (AssertionError ae) { - System.out.println("chain: " + Util.toBytesRef(arcAndState.chain, new BytesRef())); - } - } - if (scratchArc.target != -1 && inCounts[scratchArc.target] != 1) { - chain.length = 0; - } else { - chain.ints[0] = scratchArc.label; - chain.length = 1; - } - } - // TODO: instead of new Arc() we can re-use from - // a by-depth array - queue.add(new ArcAndState(new Arc().copyFrom(scratchArc), chain)); - } else if (!arcAndState.arc.isLast()) { - // next - readNextArc(arcAndState.arc); - //System.out.println(" next label=" + (char) arcAndState.arc.label + " len=" + arcAndState.chain.length); - if (arcAndState.chain.length != 0) { - arcAndState.chain.ints[arcAndState.chain.length-1] = arcAndState.arc.label; - } - } else { - if (arcAndState.chain.length > 1) { - saved += arcAndState.chain.length-2; - System.out.println("chain: " + Util.toBytesRef(arcAndState.chain, new BytesRef()).utf8ToString()); - } - // pop - //System.out.println(" pop"); - queue.remove(queue.size()-1); - while(queue.size() > 0 && queue.get(queue.size()-1).arc.isLast()) { - queue.remove(queue.size()-1); - } - if (queue.size() > 0) { - final ArcAndState arcAndState2 = queue.get(queue.size()-1); - readNextArc(arcAndState2.arc); - //System.out.println(" read next=" + (char) arcAndState2.arc.label + " queue=" + queue.size()); - assert arcAndState2.arc.label != END_LABEL; - if (arcAndState2.chain.length != 0) { - arcAndState2.chain.ints[arcAndState2.chain.length-1] = arcAndState2.arc.label; - } - } - } - } - - System.out.println("TOT saved " + saved); - } - */ - } diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java b/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java index fb821942e59..e584acd6b04 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/FSTEnum.java @@ -22,6 +22,8 @@ import java.io.IOException; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; +import static org.apache.lucene.util.fst.FST.Arc.BitTable; + /** Can next() and advance() through the terms in an FST * * @lucene.experimental @@ -36,7 +38,6 @@ abstract class FSTEnum { protected final T NO_OUTPUT; protected final FST.BytesReader fstReader; - protected final FST.Arc scratchArc = new FST.Arc<>(); protected int upto; int targetLength; @@ -178,7 +179,7 @@ abstract class FSTEnum { } else { if (targetIndex < 0) { targetIndex = -1; - } else if (arc.bitTable().isBitSet(targetIndex)) { + } else if (BitTable.isBitSet(targetIndex, arc, in)) { fst.readArcByDirectAddressing(arc, in, targetIndex); assert arc.label() == targetLabel; // found -- copy pasta from below @@ -191,7 +192,7 @@ abstract class FSTEnum { return fst.readFirstTargetArc(arc, getArc(upto), fstReader); } // Not found, return the next arc (ceil). - int ceilIndex = arc.bitTable().nextBitSet(targetIndex); + int ceilIndex = BitTable.nextBitSet(targetIndex, arc, in); assert ceilIndex != -1; fst.readArcByDirectAddressing(arc, in, ceilIndex); assert arc.label() > targetLabel; @@ -335,14 +336,14 @@ abstract class FSTEnum { return backtrackToFloorArc(arc, targetLabel, in); } else if (targetIndex >= arc.numArcs()) { // After last arc. - fst.readArcByDirectAddressing(arc, in, arc.numArcs() - 1); + fst.readLastArcByDirectAddressing(arc, in); assert arc.label() < targetLabel; assert arc.isLast(); pushLast(); return null; } else { // Within label range. - if (arc.bitTable().isBitSet(targetIndex)) { + if (BitTable.isBitSet(targetIndex, arc, in)) { fst.readArcByDirectAddressing(arc, in, targetIndex); assert arc.label() == targetLabel; // found -- copy pasta from below @@ -355,7 +356,7 @@ abstract class FSTEnum { return fst.readFirstTargetArc(arc, getArc(upto), fstReader); } // Scan backwards to find a floor arc. - int floorIndex = arc.bitTable().previousBitSet(targetIndex); + int floorIndex = BitTable.previousBitSet(targetIndex, arc, in); assert floorIndex != -1; fst.readArcByDirectAddressing(arc, in, floorIndex); assert arc.label() < targetLabel; @@ -421,10 +422,10 @@ abstract class FSTEnum { assert targetIndex >= 0; if (targetIndex >= arc.numArcs()) { // Beyond last arc. Take last arc. - fst.readArcByDirectAddressing(arc, in, arc.numArcs() - 1); + fst.readLastArcByDirectAddressing(arc, in); } else { // Take the preceding arc, even if the target is present. - int floorIndex = arc.bitTable().previousBitSet(targetIndex); + int floorIndex = BitTable.previousBitSet(targetIndex, arc, in); if (floorIndex > 0) { fst.readArcByDirectAddressing(arc, in, floorIndex); } diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java b/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java index 9fcf5f5db56..8572f54507b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java @@ -51,7 +51,7 @@ final class NodeHash { } else { assert scratchArc.nodeFlags() == FST.ARCS_FOR_DIRECT_ADDRESSING; if ((node.arcs[node.numArcs - 1].label - node.arcs[0].label + 1) != scratchArc.numArcs() - || node.numArcs != scratchArc.bitTable().countBits()) { + || node.numArcs != FST.Arc.BitTable.countBits(scratchArc, in)) { return false; } } diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/Util.java b/lucene/core/src/java/org/apache/lucene/util/fst/Util.java index 5f822c48811..e06011db3d8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/Util.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/Util.java @@ -33,6 +33,8 @@ import java.util.Iterator; import java.util.List; import java.util.TreeSet; +import static org.apache.lucene.util.fst.FST.Arc.BitTable; + /** Static helper methods. * * @lucene.experimental */ @@ -478,6 +480,7 @@ public final class Util { // For each arc leaving this node: boolean foundZero = false; + boolean arcCopyIsPending = false; while(true) { // tricky: instead of comparing output == 0, we must // express it via the comparator compare(output, 0) == 0 @@ -486,7 +489,7 @@ public final class Util { foundZero = true; break; } else if (!foundZero) { - scratchArc.copyFrom(path.arc); + arcCopyIsPending = true; foundZero = true; } else { addIfCompetitive(path); @@ -497,16 +500,16 @@ public final class Util { if (path.arc.isLast()) { break; } + if (arcCopyIsPending) { + scratchArc.copyFrom(path.arc); + arcCopyIsPending = false; + } fst.readNextArc(path.arc, fstReader); } assert foundZero; - if (queue != null) { - // TODO: maybe we can save this copyFrom if we - // are more clever above... eg on finding the - // first NO_OUTPUT arc we'd switch to using - // scratchArc + if (queue != null && !arcCopyIsPending) { path.arc.copyFrom(scratchArc); } @@ -948,11 +951,11 @@ public final class Util { } else if (targetIndex < 0) { return arc; } else { - if (arc.bitTable().isBitSet(targetIndex)) { + if (BitTable.isBitSet(targetIndex, arc, in)) { fst.readArcByDirectAddressing(arc, in, targetIndex); assert arc.label() == label; } else { - int ceilIndex = arc.bitTable().nextBitSet(targetIndex); + int ceilIndex = BitTable.nextBitSet(targetIndex, arc, in); assert ceilIndex != -1; fst.readArcByDirectAddressing(arc, in, ceilIndex); assert arc.label() > label; diff --git a/lucene/core/src/test/org/apache/lucene/util/TestBitUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestBitUtil.java deleted file mode 100644 index b31a6fc4b67..00000000000 --- a/lucene/core/src/test/org/apache/lucene/util/TestBitUtil.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.util; - -public class TestBitUtil extends LuceneTestCase { - - public void testNextBitSet() { - int numIterations = atLeast(1000); - for (int i = 0; i < numIterations; i++) { - long[] bits = buildRandomBits(); - int numLong = bits.length - 1; - - // Verify nextBitSet with countBitsUpTo for all bit indexes. - for (int bitIndex = -1; bitIndex < 64 * numLong; bitIndex++) { - int nextIndex = BitUtil.nextBitSet(bits, numLong, bitIndex); - if (nextIndex == -1) { - assertEquals("No next bit set, so expected no bit count diff" - + " (i=" + i + " bitIndex=" + bitIndex + ")", - BitUtil.countBitsUpTo(bits, numLong, bitIndex + 1), BitUtil.countBits(bits, numLong)); - } else { - assertTrue("Expected next bit set at nextIndex=" + nextIndex - + " (i=" + i + " bitIndex=" + bitIndex + ")", - BitUtil.isBitSet(bits, numLong, nextIndex)); - assertEquals("Next bit set at nextIndex=" + nextIndex - + " so expected bit count diff of 1" - + " (i=" + i + " bitIndex=" + bitIndex + ")", - BitUtil.countBitsUpTo(bits, numLong, bitIndex + 1) + 1, - BitUtil.countBitsUpTo(bits, numLong, nextIndex + 1)); - } - } - } - } - - public void testPreviousBitSet() { - int numIterations = atLeast(1000); - for (int i = 0; i < numIterations; i++) { - long[] bits = buildRandomBits(); - int numLong = bits.length - 1; - - // Verify previousBitSet with countBitsUpTo for all bit indexes. - for (int bitIndex = 0; bitIndex <= 64 * numLong; bitIndex++) { - int previousIndex = BitUtil.previousBitSet(bits, numLong, bitIndex); - if (previousIndex == -1) { - assertEquals("No previous bit set, so expected bit count 0" - + " (i=" + i + " bitIndex=" + bitIndex + ")", - 0, BitUtil.countBitsUpTo(bits, numLong, bitIndex)); - } else { - assertTrue("Expected previous bit set at previousIndex=" + previousIndex - + " (i=" + i + " bitIndex=" + bitIndex + ")", - BitUtil.isBitSet(bits, numLong, previousIndex)); - int bitCount = BitUtil.countBitsUpTo(bits, numLong, Math.min(bitIndex + 1, numLong * Long.SIZE)); - int expectedPreviousBitCount = bitIndex < numLong * Long.SIZE && BitUtil.isBitSet(bits, numLong, bitIndex) ? - bitCount - 1 : bitCount; - assertEquals("Previous bit set at previousIndex=" + previousIndex - + " with current bitCount=" + bitCount - + " so expected previousBitCount=" + expectedPreviousBitCount - + " (i=" + i + " bitIndex=" + bitIndex + ")", - expectedPreviousBitCount, BitUtil.countBitsUpTo(bits, numLong, previousIndex + 1)); - } - } - } - } - - private long[] buildRandomBits() { - long[] bits = new long[random().nextInt(3) + 2]; - for (int j = 0; j < bits.length; j++) { - // Bias towards zeros which require special logic. - bits[j] = random().nextInt(4) == 0 ? 0L : random().nextLong(); - } - return bits; - } -} diff --git a/lucene/core/src/test/org/apache/lucene/util/fst/TestBitTableUtil.java b/lucene/core/src/test/org/apache/lucene/util/fst/TestBitTableUtil.java new file mode 100644 index 00000000000..57a27db1b96 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/fst/TestBitTableUtil.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.fst; + +import java.io.IOException; + +import org.apache.lucene.util.LuceneTestCase; + +public class TestBitTableUtil extends LuceneTestCase { + + public void testNextBitSet() throws IOException { + int numIterations = atLeast(1000); + for (int i = 0; i < numIterations; i++) { + byte[] bits = buildRandomBits(); + int numBytes = bits.length - 1; + int numBits = numBytes * Byte.SIZE; + + // Verify nextBitSet with countBitsUpTo for all bit indexes. + for (int bitIndex = -1; bitIndex < numBits; bitIndex++) { + int nextIndex = BitTableUtil.nextBitSet(bitIndex, numBytes, reader(bits)); + if (nextIndex == -1) { + assertEquals("No next bit set, so expected no bit count diff" + + " (i=" + i + " bitIndex=" + bitIndex + ")", + BitTableUtil.countBitsUpTo(bitIndex + 1, reader(bits)), + BitTableUtil.countBits(numBytes, reader(bits))); + } else { + assertTrue("Expected next bit set at nextIndex=" + nextIndex + + " (i=" + i + " bitIndex=" + bitIndex + ")", + BitTableUtil.isBitSet(nextIndex, reader(bits))); + assertEquals("Next bit set at nextIndex=" + nextIndex + + " so expected bit count diff of 1" + + " (i=" + i + " bitIndex=" + bitIndex + ")", + BitTableUtil.countBitsUpTo(bitIndex + 1, reader(bits)) + 1, + BitTableUtil.countBitsUpTo(nextIndex + 1, reader(bits))); + } + } + } + } + + public void testPreviousBitSet() throws IOException { + int numIterations = atLeast(1000); + for (int i = 0; i < numIterations; i++) { + byte[] bits = buildRandomBits(); + int numBytes = bits.length - 1; + int numBits = numBytes * Byte.SIZE; + + // Verify previousBitSet with countBitsUpTo for all bit indexes. + for (int bitIndex = 0; bitIndex <= numBits; bitIndex++) { + int previousIndex = BitTableUtil.previousBitSet(bitIndex, reader(bits)); + if (previousIndex == -1) { + assertEquals("No previous bit set, so expected bit count 0" + + " (i=" + i + " bitIndex=" + bitIndex + ")", + 0, BitTableUtil.countBitsUpTo(bitIndex, reader(bits))); + } else { + assertTrue("Expected previous bit set at previousIndex=" + previousIndex + + " (i=" + i + " bitIndex=" + bitIndex + ")", + BitTableUtil.isBitSet(previousIndex, reader(bits))); + int bitCount = BitTableUtil.countBitsUpTo(Math.min(bitIndex + 1, numBits), reader(bits)); + int expectedPreviousBitCount = bitIndex < numBits && BitTableUtil.isBitSet(bitIndex, reader(bits)) ? + bitCount - 1 : bitCount; + assertEquals("Previous bit set at previousIndex=" + previousIndex + + " with current bitCount=" + bitCount + + " so expected previousBitCount=" + expectedPreviousBitCount + + " (i=" + i + " bitIndex=" + bitIndex + ")", + expectedPreviousBitCount, BitTableUtil.countBitsUpTo(previousIndex + 1, reader(bits))); + } + } + } + } + + private byte[] buildRandomBits() { + byte[] bits = new byte[random().nextInt(24) + 2]; + for (int i = 0; i < bits.length; i++) { + // Bias towards zeros which require special logic. + bits[i] = random().nextInt(4) == 0 ? 0 : (byte) random().nextInt(); + } + return bits; + } + + private static FST.BytesReader reader(byte[] bits) { + return new ByteArrayBytesReader(bits); + } + + private static class ByteArrayBytesReader extends FST.BytesReader { + + private final byte[] bits; + private int position; + + ByteArrayBytesReader(byte[] bits) { + this.bits = bits; + } + + @Override + public long getPosition() { + return position; + } + + @Override + public void setPosition(long pos) { + position = (int) pos; + } + + @Override + public boolean reversed() { + return false; + } + + @Override + public byte readByte() { + return bits[position++]; + } + + @Override + public void readBytes(byte[] b, int offset, int len) { + throw new UnsupportedOperationException(); + } + + @Override + public void skipBytes(long numBytes) { + position += numBytes; + } + } +} \ No newline at end of file diff --git a/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTDirectAddressing.java b/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTDirectAddressing.java index c3ea01f184c..25ea6f6fe0e 100644 --- a/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTDirectAddressing.java +++ b/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTDirectAddressing.java @@ -18,8 +18,10 @@ package org.apache.lucene.util.fst; import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; @@ -28,10 +30,13 @@ import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Set; +import java.util.zip.GZIPInputStream; import org.apache.lucene.store.ByteArrayDataInput; import org.apache.lucene.store.DataInput; +import org.apache.lucene.store.InputStreamDataInput; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.CharsRef; import org.apache.lucene.util.IntsRefBuilder; import org.apache.lucene.util.LuceneTestCase; @@ -151,17 +156,23 @@ public class TestFSTDirectAddressing extends LuceneTestCase { if (args.length < 2) { throw new IllegalArgumentException("Missing argument"); } - if (args[0].equals("-countFSTArcs")) { - countFSTArcs(args[1]); - } else if (args[0].equals("-measureFSTOversizing")) { - measureFSTOversizing(args[1]); - } else { - throw new IllegalArgumentException("Invalid argument " + args[0]); + switch (args[0]) { + case "-countFSTArcs": + countFSTArcs(args[1]); + break; + case "-measureFSTOversizing": + measureFSTOversizing(args[1]); + break; + case "-recompileAndWalk": + recompileAndWalk(args[1]); + break; + default: + throw new IllegalArgumentException("Invalid argument " + args[0]); } } - private static void countFSTArcs(String FSTFilePath) throws IOException { - byte[] buf = Files.readAllBytes(Paths.get(FSTFilePath)); + private static void countFSTArcs(String fstFilePath) throws IOException { + byte[] buf = Files.readAllBytes(Paths.get(fstFilePath)); DataInput in = new ByteArrayDataInput(buf); FST fst = new FST<>(in, ByteSequenceOutputs.getSingleton()); BytesRefFSTEnum fstEnum = new BytesRefFSTEnum<>(fst); @@ -211,4 +222,62 @@ public class TestFSTDirectAddressing extends LuceneTestCase { printStats(fstCompiler, ramBytesUsed, directAddressingMemoryIncreasePercent); } + + private static void recompileAndWalk(String fstFilePath) throws IOException { + try (InputStreamDataInput in = new InputStreamDataInput(newInputStream(Paths.get(fstFilePath)))) { + + System.out.println("Reading FST"); + long startTimeMs = System.currentTimeMillis(); + FST originalFst = new FST<>(in, CharSequenceOutputs.getSingleton()); + long endTimeMs = System.currentTimeMillis(); + System.out.println("time = " + (endTimeMs - startTimeMs) + " ms"); + + for (float oversizingFactor : List.of(0f, 0f, 0f, 1f, 1f, 1f)) { + System.out.println("\nFST construction (oversizingFactor=" + oversizingFactor + ")"); + startTimeMs = System.currentTimeMillis(); + FST fst = recompile(originalFst, oversizingFactor); + endTimeMs = System.currentTimeMillis(); + System.out.println("time = " + (endTimeMs - startTimeMs) + " ms"); + System.out.println("FST RAM = " + fst.ramBytesUsed() + " B"); + + System.out.println("FST enum"); + startTimeMs = System.currentTimeMillis(); + walk(fst); + endTimeMs = System.currentTimeMillis(); + System.out.println("time = " + (endTimeMs - startTimeMs) + " ms"); + } + } + } + + private static InputStream newInputStream(Path path) throws IOException { + InputStream in = Files.newInputStream(path); + String fileName = path.getFileName().toString(); + if (fileName.endsWith("gz") || fileName.endsWith("zip")) { + in = new GZIPInputStream(in); + } + return in; + } + + private static FST recompile(FST fst, float oversizingFactor) throws IOException { + FSTCompiler fstCompiler = new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE4, CharSequenceOutputs.getSingleton()) + .directAddressingMaxOversizingFactor(oversizingFactor) + .build(); + IntsRefFSTEnum fstEnum = new IntsRefFSTEnum<>(fst); + IntsRefFSTEnum.InputOutput inputOutput; + while ((inputOutput = fstEnum.next()) != null) { + fstCompiler.add(inputOutput.input, CharsRef.deepCopyOf(inputOutput.output)); + } + return fstCompiler.compile(); + } + + private static int walk(FST read) throws IOException { + IntsRefFSTEnum fstEnum = new IntsRefFSTEnum<>(read); + IntsRefFSTEnum.InputOutput inputOutput; + int terms = 0; + while ((inputOutput = fstEnum.next()) != null) { + terms += inputOutput.input.length; + terms += inputOutput.output.length; + } + return terms; + } }