LUCENE-9286: FST.Arc.BitTable reads directly FST bytes. Arc is lightweight again and FSTEnum traversal faster.

This commit is contained in:
Bruno Roustant 2020-04-09 10:36:37 +02:00
parent 4f92cd414c
commit 6bba35a709
No known key found for this signature in database
GPG Key ID: CD28DABB95360525
10 changed files with 548 additions and 477 deletions

View File

@ -144,6 +144,9 @@ Optimizations
* LUCENE-9287: UsageTrackingQueryCachingPolicy no longer caches DocValuesFieldExistsQuery. (Ignacio Vera)
* LUCENE-9286: FST.Arc.BitTable reads FST bytes directly. Arc is lightweight again and FSTEnum traversal is faster.
(Bruno Roustant)
Bug Fixes
---------------------
* LUCENE-9259: Fix wrong NGramFilterFactory argument name for preserveOriginal option (Paul Pazderski)

View File

@ -155,7 +155,7 @@ public final class BitUtil {
/**
* flip flops odd with even bits
*/
public static final long flipFlop(final long b) {
public static long flipFlop(final long b) {
return ((b & MAGIC6) >>> 1) | ((b & MAGIC0) << 1 );
}
@ -183,130 +183,4 @@ public final class BitUtil {
public static long zigZagDecode(long l) {
  // The lowest bit selects the sign; the remaining bits hold the magnitude.
  final long magnitude = l >>> 1;
  // 0 when the lowest bit is clear, all ones (-1) when it is set.
  final long signMask = -(l & 1);
  return magnitude ^ signMask;
}
/**
 * Tests whether a single bit is set in a bit set stored as an array of longs.
 * <br>Example: bitIndex 66 designates the third bit from the right of the second long.
 *
 * @param bits     The longs holding the bits.
 * @param numLongs The number of longs in {@code bits} to consider.
 * @param bitIndex The zero-based index of the bit to test. It must satisfy
 *                 {@code 0 <= bitIndex < numLongs * Long.SIZE}.
 * @return {@code true} if the designated bit is set.
 */
public static boolean isBitSet(long[] bits, int numLongs, int bitIndex) {
  assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= 0 && bitIndex < numLongs * Long.SIZE
      : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length;
  final long word = bits[bitIndex / Long.SIZE];
  // The shift distance of a long is taken modulo 64, so this selects the bit inside the word.
  final long probe = 1L << bitIndex;
  return (word & probe) != 0;
}
/**
 * Returns the total number of bits set in the first {@code numLongs} longs of {@code bits}.
 *
 * @param bits     The longs holding the bits.
 * @param numLongs The number of longs in {@code bits} to consider.
 * @return The count of bits set.
 */
public static int countBits(long[] bits, int numLongs) {
  assert numLongs >= 0 && numLongs <= bits.length
      : "numLongs=" + numLongs + " bits.length=" + bits.length;
  int total = 0;
  // The sum does not depend on the visiting order, so walk the longs from last to first.
  int remaining = numLongs;
  while (remaining-- > 0) {
    total += Long.bitCount(bits[remaining]);
  }
  return total;
}
/**
 * Counts the bits set strictly before the given zero-based bit index.
 * <br>Example: bitIndex 66 designates the third bit from the right of the second long.
 *
 * @param bits     The longs holding the bits.
 * @param numLongs The number of longs in {@code bits} to consider.
 * @param bitIndex The exclusive upper bound, as a zero-based bit index. It must satisfy
 *                 {@code 0 <= bitIndex <= numLongs * Long.SIZE}.
 * @return The number of bits set at indexes lower than {@code bitIndex}.
 */
public static int countBitsUpTo(long[] bits, int numLongs, int bitIndex) {
  assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= 0 && bitIndex <= numLongs * Long.SIZE
      : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length;
  final int fullLongs = bitIndex / Long.SIZE;
  int count = 0;
  // Longs entirely below bitIndex contribute all their bits.
  int i = 0;
  while (i < fullLongs) {
    count += Long.bitCount(bits[i++]);
  }
  if (fullLongs >= numLongs) {
    // bitIndex sits exactly at the end of the considered longs: no partial long to inspect.
    return count;
  }
  // Keep only the bits below bitIndex in the partial long (shift distance is taken modulo 64).
  final long lowBits = ~(-1L << bitIndex);
  return count + Long.bitCount(bits[fullLongs] & lowBits);
}
/**
 * Returns the index of the next bit set following the given bit zero-based index.
 * <br>For example with bits 100011:
 * the next bit set after index=-1 is at index=0;
 * the next bit set after index=0 is at index=1;
 * the next bit set after index=1 is at index=5;
 * there is no next bit set after index=5.
 *
 * @param bits     The bits stored in an array of long for efficiency.
 * @param numLongs The number of longs in {@code bits} to consider. When it is 0 there is no bit
 *                 to find and -1 is returned without reading {@code bits}.
 * @param bitIndex The bit zero-based index. It must be greater than or equal to -1,
 *                 and strictly less than {@code numLongs * Long.SIZE}.
 * @return The zero-based index of the next bit set after the provided {@code bitIndex};
 *         or -1 if none.
 */
public static int nextBitSet(long[] bits, int numLongs, int bitIndex) {
  assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= -1 && bitIndex < numLongs * Long.SIZE
      : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length;
  if (numLongs == 0) {
    // Only bitIndex == -1 is legal here. Without this guard bits[0] would be read, which is
    // either out of bounds or a long that must not be considered.
    return -1;
  }
  int longIndex = bitIndex / Long.SIZE;
  // Prepare a mask with 1s on the left down to bitIndex exclusive.
  long mask = -(1L << (bitIndex + 1)); // Shifts are mod 64.
  // mask == -1 with bitIndex != -1 means bitIndex is the last bit of its long:
  // nothing remains in that long, so start from the following one.
  long l = mask == -1 && bitIndex != -1 ? 0 : bits[longIndex] & mask;
  while (l == 0) {
    if (++longIndex == numLongs) {
      return -1;
    }
    l = bits[longIndex];
  }
  return Long.numberOfTrailingZeros(l) + longIndex * Long.SIZE;
}
/**
 * Finds the closest bit set strictly before the given zero-based bit index.
 * <br>For example with bits 100011:
 * there is no previous bit set before index=0;
 * the previous bit set before index=1 is at index=0;
 * the previous bit set before index=5 is at index=1;
 * the previous bit set before index=64 is at index=5.
 *
 * @param bits     The longs holding the bits.
 * @param numLongs The number of longs in {@code bits} to consider.
 * @param bitIndex The zero-based bit index to search below. It must satisfy
 *                 {@code 0 <= bitIndex <= numLongs * Long.SIZE}.
 * @return The zero-based index of the previous bit set before {@code bitIndex}; or -1 if none.
 */
public static int previousBitSet(long[] bits, int numLongs, int bitIndex) {
  assert numLongs >= 0 && numLongs <= bits.length && bitIndex >= 0 && bitIndex <= numLongs * Long.SIZE
      : "bitIndex=" + bitIndex + " numLongs=" + numLongs + " bits.length=" + bits.length;
  int wordIndex = bitIndex / Long.SIZE;
  long word = 0;
  if (wordIndex != numLongs) {
    // Keep only the bits below bitIndex in its long (shift distance is taken modulo 64).
    word = bits[wordIndex] & ~(-1L << bitIndex);
  }
  // Walk towards the first long until a long with at least one bit set is found.
  for (; word == 0; word = bits[wordIndex]) {
    if (wordIndex-- == 0) {
      return -1;
    }
  }
  final int highestBit = (Long.SIZE - 1) - Long.numberOfLeadingZeros(word);
  return highestBit + wordIndex * 64;
}
}

View File

@ -0,0 +1,179 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.fst;
import java.io.IOException;
/**
 * Static helper methods for {@link FST.Arc.BitTable}.
 * <p>Every method receives an {@link FST.BytesReader} that must be positioned at the beginning of the
 * bit-table. Bytes are consumed in reading order; the first byte holds bits 0 to 7, the second byte
 * holds bits 8 to 15, and so on.
 *
 * @lucene.experimental
 */
class BitTableUtil {

  /**
   * Tells whether the bit at the given zero-based index is set.
   * <br>Example: bitIndex 10 means the third bit on the right of the second byte.
   *
   * @param bitIndex The bit zero-based index. It must be greater than or equal to 0, and strictly less than
   *                 {@code number of bit-table bytes * Byte.SIZE}.
   * @param reader   The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table.
   * @return {@code true} if the bit is set.
   */
  static boolean isBitSet(int bitIndex, FST.BytesReader reader) throws IOException {
    assert bitIndex >= 0 : "bitIndex=" + bitIndex;
    // Jump to the byte that holds the bit, then test the bit inside that byte.
    reader.skipBytes(bitIndex >> 3);
    final long octet = readByte(reader);
    final long probe = 1L << (bitIndex & (Byte.SIZE - 1));
    return (octet & probe) != 0;
  }

  /**
   * Counts all the bits set in the bit-table.
   *
   * @param bitTableBytes The number of bytes in the bit-table.
   * @param reader        The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table.
   * @return The number of bits set in the whole bit-table.
   */
  static int countBits(int bitTableBytes, FST.BytesReader reader) throws IOException {
    assert bitTableBytes >= 0 : "bitTableBytes=" + bitTableBytes;
    int total = 0;
    // Consume the table eight bytes at a time while full groups remain.
    int fullLongs = bitTableBytes >> 3;
    while (fullLongs-- > 0) {
      total += Long.bitCount(read8Bytes(reader));
    }
    // Then consume the trailing group of fewer than eight bytes, if any.
    final int tailBytes = bitTableBytes & (Long.BYTES - 1);
    if (tailBytes != 0) {
      total += Long.bitCount(readUpTo8Bytes(tailBytes, reader));
    }
    return total;
  }

  /**
   * Counts the bits set up to the given bit zero-based index, exclusive.
   * <br>In other words, how many 1s there are before the bit at the given index.
   * <br>Example: bitIndex 10 means the third bit on the right of the second byte.
   *
   * @param bitIndex The bit zero-based index, exclusive. It must be greater than or equal to 0, and less than or
   *                 equal to {@code number of bit-table bytes * Byte.SIZE}.
   * @param reader   The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table.
   * @return The number of bits set at indexes lower than {@code bitIndex}.
   */
  static int countBitsUpTo(int bitIndex, FST.BytesReader reader) throws IOException {
    assert bitIndex >= 0 : "bitIndex=" + bitIndex;
    int total = 0;
    // Groups of 64 bits entirely below bitIndex contribute all their bits.
    int fullLongs = bitIndex >> 6;
    while (fullLongs-- > 0) {
      total += Long.bitCount(read8Bytes(reader));
    }
    final int tailBits = bitIndex & (Long.SIZE - 1);
    if (tailBits != 0) {
      // Read just enough bytes to cover the remaining bits.
      final int tailBytes = (tailBits + (Byte.SIZE - 1)) >> 3;
      // Keep only the bits below bitIndex (shift distance is taken modulo 64).
      final long lowBits = ~(-1L << bitIndex);
      total += Long.bitCount(readUpTo8Bytes(tailBytes, reader) & lowBits);
    }
    return total;
  }

  /**
   * Returns the index of the next bit set following the given bit zero-based index.
   * <br>For example with bits 100011:
   * the next bit set after index=-1 is at index=0;
   * the next bit set after index=0 is at index=1;
   * the next bit set after index=1 is at index=5;
   * there is no next bit set after index=5.
   *
   * @param bitIndex      The bit zero-based index. It must be greater than or equal to -1, and strictly less than
   *                      {@code number of bit-table bytes * Byte.SIZE}.
   * @param bitTableBytes The number of bytes in the bit-table.
   * @param reader        The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table.
   * @return The zero-based index of the next bit set after the provided {@code bitIndex}; or -1 if none.
   */
  static int nextBitSet(int bitIndex, int bitTableBytes, FST.BytesReader reader) throws IOException {
    assert bitIndex >= -1 && bitIndex < bitTableBytes * Byte.SIZE : "bitIndex=" + bitIndex + " bitTableBytes=" + bitTableBytes;
    int byteIndex = bitIndex / Byte.SIZE;
    // Position, inside its byte, of the first bit that may be returned.
    final int firstCandidate = (bitIndex + 1) & (Byte.SIZE - 1);
    int candidates;
    if (firstCandidate == 0 && bitIndex != -1) {
      // bitIndex is the last bit of its byte: nothing remains there, skip that byte entirely.
      reader.skipBytes(byteIndex + 1);
      candidates = 0;
    } else {
      // Keep only the bits strictly after bitIndex in its byte.
      reader.skipBytes(byteIndex);
      candidates = (reader.readByte() & 0xFF) & (-1 << firstCandidate);
    }
    // Scan forward byte by byte until a byte with at least one bit set is found.
    for (; candidates == 0; candidates = reader.readByte() & 0xFF) {
      if (++byteIndex == bitTableBytes) {
        return -1;
      }
    }
    return (byteIndex << 3) + Integer.numberOfTrailingZeros(candidates);
  }

  /**
   * Returns the index of the previous bit set preceding the given bit zero-based index.
   * <br>For example with bits 100011:
   * there is no previous bit set before index=0;
   * the previous bit set before index=1 is at index=0;
   * the previous bit set before index=5 is at index=1;
   * the previous bit set before index=64 is at index=5.
   *
   * @param bitIndex The bit zero-based index. It must be greater than or equal to 0, and less than or equal to
   *                 {@code number of bit-table bytes * Byte.SIZE}.
   * @param reader   The {@link FST.BytesReader} to read. It must be positioned at the beginning of the bit-table.
   * @return The zero-based index of the previous bit set before the provided {@code bitIndex}; or -1 if none.
   */
  static int previousBitSet(int bitIndex, FST.BytesReader reader) throws IOException {
    assert bitIndex >= 0 : "bitIndex=" + bitIndex;
    int byteIndex = bitIndex >> 3;
    reader.skipBytes(byteIndex);
    // Keep only the bits strictly before bitIndex in its byte.
    final int lowBits = (1 << (bitIndex & (Byte.SIZE - 1))) - 1;
    int candidates = (reader.readByte() & 0xFF) & lowBits;
    // Scan backward byte by byte until a byte with at least one bit set is found.
    for (; candidates == 0; candidates = reader.readByte() & 0xFF) {
      if (byteIndex-- == 0) {
        return -1;
      }
      // Step back over the byte just read and the one before it.
      reader.skipBytes(-2); // FST.BytesReader implementations support negative skip.
    }
    final int highestBit = (Integer.SIZE - 1) - Integer.numberOfLeadingZeros(candidates);
    return highestBit + (byteIndex << 3);
  }

  /** Reads the next byte as an unsigned value widened to a long. */
  private static long readByte(FST.BytesReader reader) throws IOException {
    return reader.readByte() & 0xFFL;
  }

  /**
   * Reads {@code numBytes} bytes (between 1 and 8) and packs them into a long,
   * the first byte read ending up in the lowest 8 bits.
   */
  private static long readUpTo8Bytes(int numBytes, FST.BytesReader reader) throws IOException {
    assert numBytes > 0 && numBytes <= 8 : "numBytes=" + numBytes;
    long packed = readByte(reader);
    for (int shift = Byte.SIZE; --numBytes != 0; shift += Byte.SIZE) {
      packed |= readByte(reader) << shift;
    }
    return packed;
  }

  /** Reads exactly 8 bytes and packs them into a long, the first byte read ending up in the lowest 8 bits. */
  private static long read8Bytes(FST.BytesReader reader) throws IOException {
    long packed = 0;
    for (int shift = 0; shift < Long.SIZE; shift += Byte.SIZE) {
      packed |= readByte(reader) << shift;
    }
    return packed;
  }
}

View File

@ -33,10 +33,11 @@ import org.apache.lucene.store.InputStreamDataInput;
import org.apache.lucene.store.OutputStreamDataOutput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.RamUsageEstimator;
import static org.apache.lucene.util.fst.FST.Arc.BitTable;
// TODO: break this into WritableFST and ReadOnlyFST.. then
// we can have subclasses of ReadOnlyFST to handle the
// different byte[] level encodings (packed or
@ -164,23 +165,36 @@ public final class FST<T> implements Accountable {
private long nextArc;
private int arcIdx;
private byte nodeFlags;
//*** Fields for arcs belonging to a node with fixed length arcs.
// So only valid when bytesPerArc != 0.
private byte nodeFlags;
private long posArcsStart;
// nodeFlags == ARCS_FOR_BINARY_SEARCH || nodeFlags == ARCS_FOR_DIRECT_ADDRESSING.
private int bytesPerArc;
private long posArcsStart;
private int arcIdx;
private int numArcs;
private BitTable bitTable;
//*** Fields for a direct addressing node. nodeFlags == ARCS_FOR_DIRECT_ADDRESSING.
/** Start position in the {@link FST.BytesReader} of the presence bits for a direct addressing node, aka the bit-table */
private long bitTableStart;
/** First label of a direct addressing node. */
private int firstLabel;
/**
* Index of the current label of a direct addressing node. While {@link #arcIdx} is the current index in the label
* range, {@link #presenceIndex} is its corresponding index in the list of actually present labels. It is equal
* to the number of bits set before the bit at {@link #arcIdx} in the bit-table. This field is a cache that avoids
* counting the set bits repeatedly when iterating over the next arcs.
*/
private int presenceIndex;
/** Returns this */
public Arc<T> copyFrom(Arc<T> other) {
label = other.label();
@ -191,15 +205,18 @@ public final class FST<T> implements Accountable {
nextArc = other.nextArc();
nodeFlags = other.nodeFlags();
bytesPerArc = other.bytesPerArc();
if (bytesPerArc() != 0) {
posArcsStart = other.posArcsStart();
arcIdx = other.arcIdx();
numArcs = other.numArcs();
if (nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING) {
bitTable = other.bitTable() == null ? null : other.bitTable().copy();
firstLabel = other.firstLabel();
}
}
// Fields for arcs belonging to a node with fixed length arcs.
// We could avoid copying them if bytesPerArc() == 0 (this was the case with previous code, and the current code
// still supports that), but it may actually help external uses of FST to have consistent arc state, and debugging
// is easier.
posArcsStart = other.posArcsStart();
arcIdx = other.arcIdx();
numArcs = other.numArcs();
bitTableStart = other.bitTableStart;
firstLabel = other.firstLabel();
presenceIndex = other.presenceIndex;
return this;
}
@ -239,7 +256,8 @@ public final class FST<T> implements Accountable {
b.append(" nextFinalOutput=").append(nextFinalOutput());
}
if (bytesPerArc() != 0) {
b.append(" arcArray(idx=").append(arcIdx()).append(" of ").append(numArcs()).append(")");
b.append(" arcArray(idx=").append(arcIdx()).append(" of ").append(numArcs()).append(")")
.append("(").append(nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING ? "da" : "bs").append(")");
}
return b.toString();
}
@ -303,21 +321,6 @@ public final class FST<T> implements Accountable {
return numArcs;
}
/**
 * Returns the table of bits of a direct addressing node.
 * Only valid if nodeFlags == {@link #ARCS_FOR_DIRECT_ADDRESSING};
 * may be null otherwise.
 */
BitTable bitTable() {
  return this.bitTable;
}
/**
 * Returns the table of bits of a direct addressing node, creating and
 * storing a new one on first use.
 */
BitTable getOrCreateBitTable() {
  BitTable table = bitTable;
  if (table != null) {
    return table;
  }
  // First access: create the table and remember it for later calls.
  table = new BitTable();
  bitTable = table;
  return table;
}
/** First label of a direct addressing node.
* Only valid if nodeFlags == {@link #ARCS_FOR_DIRECT_ADDRESSING}. */
int firstLabel() {
@ -325,65 +328,63 @@ public final class FST<T> implements Accountable {
}
/**
* Reusable table of bits using an array of long internally.
* Helper methods to read the bit-table of a direct addressing node.
* Only valid for {@link Arc} with {@link Arc#nodeFlags()} == {@code ARCS_FOR_DIRECT_ADDRESSING}.
*/
static class BitTable {
private long[] bits;
private int numLongs;
/** Sets the number of longs in the internal long array.
* Enlarges it if needed. Always clears the array. */
BitTable setNumLongs(int numLongs) {
assert numLongs >= 0;
this.numLongs = numLongs;
if (bits == null || bits.length < numLongs) {
bits = new long[ArrayUtil.oversize(numLongs, Long.BYTES)];
} else {
for (int i = 0; i < numLongs; i++) {
bits[i] = 0L;
}
}
return this;
/**
 * Tells whether the presence bit at {@code bitIndex} is set in the bit-table of {@code arc}.
 * Repositions {@code in} at the start of the bit-table, then delegates to
 * {@link BitTableUtil#isBitSet(int, FST.BytesReader)}.
 *
 * @param bitIndex The zero-based index of the bit to test.
 * @param arc      An arc of a direct addressing node. Declared as {@code Arc<?>} rather than a raw type;
 *                 the output type is not used here.
 * @param in       The reader over the FST bytes; its position is modified.
 */
static boolean isBitSet(int bitIndex, Arc<?> arc, FST.BytesReader in) throws IOException {
  assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
  in.setPosition(arc.bitTableStart);
  return BitTableUtil.isBitSet(bitIndex, in);
}
/** Creates a new {@link BitTable} by copying this one. */
BitTable copy() {
BitTable bitTable = new BitTable();
bitTable.bits = ArrayUtil.copyOfSubArray(bits, 0, bits.length);
bitTable.numLongs = numLongs;
return bitTable;
/**
 * Counts all the bits set in the bit-table of {@code arc}.
 * See {@link BitTableUtil#countBits(int, FST.BytesReader)}.
 * The count of bits set is the number of arcs of a direct addressing node.
 *
 * @param arc An arc of a direct addressing node. Declared as {@code Arc<?>} rather than a raw type;
 *            the output type is not used here.
 * @param in  The reader over the FST bytes; its position is modified.
 */
static int countBits(Arc<?> arc, FST.BytesReader in) throws IOException {
  assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
  in.setPosition(arc.bitTableStart);
  // arc.numArcs() is passed as the label range to size the bit-table.
  return BitTableUtil.countBits(getNumPresenceBytes(arc.numArcs()), in);
}
boolean assertIsValid() {
assert numLongs > 0 && numLongs <= bits.length;
/**
 * Counts the bits set before {@code bitIndex} (exclusive) in the bit-table of {@code arc}.
 * Repositions {@code in} at the start of the bit-table, then delegates to
 * {@link BitTableUtil#countBitsUpTo(int, FST.BytesReader)}.
 *
 * @param bitIndex The zero-based bit index, exclusive.
 * @param arc      An arc of a direct addressing node. Declared as {@code Arc<?>} rather than a raw type;
 *                 the output type is not used here.
 * @param in       The reader over the FST bytes; its position is modified.
 */
static int countBitsUpTo(int bitIndex, Arc<?> arc, FST.BytesReader in) throws IOException {
  assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
  in.setPosition(arc.bitTableStart);
  return BitTableUtil.countBitsUpTo(bitIndex, in);
}
/**
 * Returns the index of the next bit set after {@code bitIndex} in the bit-table of {@code arc}, or -1 if none.
 * Repositions {@code in} at the start of the bit-table, then delegates to
 * {@link BitTableUtil#nextBitSet(int, int, FST.BytesReader)}.
 *
 * @param bitIndex The zero-based bit index to search after.
 * @param arc      An arc of a direct addressing node. Declared as {@code Arc<?>} rather than a raw type;
 *                 the output type is not used here.
 * @param in       The reader over the FST bytes; its position is modified.
 */
static int nextBitSet(int bitIndex, Arc<?> arc, FST.BytesReader in) throws IOException {
  assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
  in.setPosition(arc.bitTableStart);
  // arc.numArcs() is passed as the label range to size the bit-table.
  return BitTableUtil.nextBitSet(bitIndex, getNumPresenceBytes(arc.numArcs()), in);
}
/**
 * Returns the index of the previous bit set before {@code bitIndex} in the bit-table of {@code arc},
 * or -1 if none. Repositions {@code in} at the start of the bit-table, then delegates to
 * {@link BitTableUtil#previousBitSet(int, FST.BytesReader)}.
 *
 * @param bitIndex The zero-based bit index to search before.
 * @param arc      An arc of a direct addressing node. Declared as {@code Arc<?>} rather than a raw type;
 *                 the output type is not used here.
 * @param in       The reader over the FST bytes; its position is modified.
 */
static int previousBitSet(int bitIndex, Arc<?> arc, FST.BytesReader in) throws IOException {
  assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
  in.setPosition(arc.bitTableStart);
  return BitTableUtil.previousBitSet(bitIndex, in);
}
/**
 * Asserts the bit-table of the provided {@link Arc} is valid.
 * <p>All the checks are {@code assert} statements, so nothing is read from {@code in} when assertions
 * are disabled. When they are enabled, the bit-table checks reposition {@code in}.
 *
 * @param arc An arc of a direct addressing node. Declared as {@code Arc<?>} rather than a raw type;
 *            the output type is not used here.
 * @param in  The reader over the FST bytes.
 * @return Always {@code true}, so that this method can itself be called inside an {@code assert}.
 */
static boolean assertIsValid(Arc<?> arc, FST.BytesReader in) throws IOException {
  assert arc.bytesPerArc() > 0;
  assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
  // First bit must be set.
  assert isBitSet(0, arc, in);
  // Last bit must be set.
  assert isBitSet(arc.numArcs() - 1, arc, in);
  // No bit set after the last arc.
  assert nextBitSet(arc.numArcs() - 1, arc, in) == -1;
  return true;
}
/** Delegates to {@link BitUtil#isBitSet(long[], int, int)} using this table's longs. */
boolean isBitSet(int bitIndex) {
  final boolean set = BitUtil.isBitSet(bits, numLongs, bitIndex);
  return set;
}
/** Delegates to {@link BitUtil#countBits(long[], int)} using this table's longs. */
int countBits() {
  final int count = BitUtil.countBits(bits, numLongs);
  return count;
}
/** Delegates to {@link BitUtil#countBitsUpTo(long[], int, int)} using this table's longs. */
int countBitsUpTo(int bitIndex) {
  final int count = BitUtil.countBitsUpTo(bits, numLongs, bitIndex);
  return count;
}
/** Delegates to {@link BitUtil#nextBitSet(long[], int, int)} using this table's longs. */
int nextBitSet(int bitIndex) {
  final int next = BitUtil.nextBitSet(bits, numLongs, bitIndex);
  return next;
}
/** Delegates to {@link BitUtil#previousBitSet(long[], int, int)} using this table's longs. */
int previousBitSet(int bitIndex) {
  final int previous = BitUtil.previousBitSet(bits, numLongs, bitIndex);
  return previous;
}
}
}
@ -921,41 +922,19 @@ public final class FST<T> implements Accountable {
/** Gets the number of bytes required to flag the presence of each arc in the given label range, one bit per arc. */
private static int getNumPresenceBytes(int labelRange) {
return (labelRange + 7) / Byte.SIZE;
assert labelRange >= 0;
return (labelRange + 7) >> 3;
}
/**
* Reads the presence bits of a direct-addressing node, store them in the provided arc {@link Arc#bitTable()}
* and returns the number of presence bytes.
* Reads the presence bits of a direct-addressing node.
* Actually we don't read them here, we just keep the pointer to the bit-table start and we skip them.
*/
private int readPresenceBytes(Arc<T> arc, BytesReader in) throws IOException {
int numPresenceBytes = getNumPresenceBytes(arc.numArcs());
Arc.BitTable presenceBits = arc.getOrCreateBitTable().setNumLongs((numPresenceBytes + 7) / Long.BYTES);
for (int i = 0; i < numPresenceBytes; i++) {
// Read the next unsigned byte, shift it to the left, and appends it to the current long.
presenceBits.bits[i / Long.BYTES] |= (in.readByte() & 0xFFL) << (i * Byte.SIZE);
}
assert assertPresenceBytesAreValid(arc);
return numPresenceBytes;
}
private int getNumArcsDirectAddressing(Arc<T> arc) {
private void readPresenceBytes(Arc<T> arc, BytesReader in) throws IOException {
assert arc.bytesPerArc() > 0;
assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
return arc.bitTable().countBits();
}
private boolean assertPresenceBytesAreValid(Arc<T> arc) {
assert arc.bitTable() != null;
assert arc.bitTable().assertIsValid();
// First bit must be set.
assert arc.bitTable().isBitSet(0);
// Last bit must be set.
assert arc.bitTable().isBitSet(arc.numArcs() - 1);
// No bit set after the last arc.
assert arc.bitTable().nextBitSet(arc.numArcs() - 1) == -1;
// Total bit set (real num arcs) must be <= label range (stored in arc.numArcs()).
assert getNumArcsDirectAddressing(arc) <= arc.numArcs();
return true;
arc.bitTableStart = in.getPosition();
in.skipBytes(getNumPresenceBytes(arc.numArcs()));
}
/** Fills virtual 'start' arc, ie, an empty incoming arc to the FST's start node */
@ -1010,7 +989,7 @@ public final class FST<T> implements Accountable {
readPresenceBytes(arc, in);
arc.firstLabel = readLabel(in);
arc.posArcsStart = in.getPosition();
readArcByDirectAddressing(arc, in, arc.numArcs() - 1);
readLastArcByDirectAddressing(arc, in);
} else {
arc.arcIdx = arc.numArcs() - 2;
arc.posArcsStart = in.getPosition();
@ -1095,6 +1074,7 @@ public final class FST<T> implements Accountable {
if (flags == ARCS_FOR_DIRECT_ADDRESSING) {
readPresenceBytes(arc, in);
arc.firstLabel = readLabel(in);
arc.presenceIndex = -1;
}
arc.posArcsStart = in.getPosition();
//System.out.println(" bytesPer=" + arc.bytesPerArc + " numArcs=" + arc.numArcs + " arcsStart=" + pos);
@ -1166,9 +1146,9 @@ public final class FST<T> implements Accountable {
assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
// Direct addressing node. The label is not stored but rather inferred
// based on first label and arc index in the range.
assert assertPresenceBytesAreValid(arc);
assert arc.bitTable().isBitSet(arc.arcIdx());
int nextIndex = arc.bitTable().nextBitSet(arc.arcIdx());
assert BitTable.assertIsValid(arc, in);
assert BitTable.isBitSet(arc.arcIdx(), arc, in);
int nextIndex = BitTable.nextBitSet(arc.arcIdx(), arc, in);
assert nextIndex != -1;
return arc.firstLabel() + nextIndex;
}
@ -1183,6 +1163,8 @@ public final class FST<T> implements Accountable {
}
public Arc<T> readArcByIndex(Arc<T> arc, final BytesReader in, int idx) throws IOException {
assert arc.bytesPerArc() > 0;
assert arc.nodeFlags() == ARCS_FOR_BINARY_SEARCH;
assert idx >= 0 && idx < arc.numArcs();
in.setPosition(arc.posArcsStart() - idx * arc.bytesPerArc());
arc.arcIdx = idx;
@ -1190,25 +1172,44 @@ public final class FST<T> implements Accountable {
return readArc(arc, in);
}
/** Reads a present direct addressing node arc, with the provided index in the label range.
/**
* Reads a present direct addressing node arc, with the provided index in the label range.
*
* @param rangeIndex The index of the arc in the label range. It must be present.
* The real arc offset is computed based on the presence bits of
* the direct addressing node.
*/
public Arc<T> readArcByDirectAddressing(Arc<T> arc, final BytesReader in, int rangeIndex) throws IOException {
assert arc.nodeFlags() == ARCS_FOR_DIRECT_ADDRESSING;
assert arc.bytesPerArc() > 0;
assert assertPresenceBytesAreValid(arc);
assert BitTable.assertIsValid(arc, in);
assert rangeIndex >= 0 && rangeIndex < arc.numArcs();
assert arc.bitTable().isBitSet(rangeIndex);
int presenceIndex = arc.bitTable().countBitsUpTo(rangeIndex);
assert BitTable.isBitSet(rangeIndex, arc, in);
int presenceIndex = BitTable.countBitsUpTo(rangeIndex, arc, in);
return readArcByDirectAddressing(arc, in, rangeIndex, presenceIndex);
}
/**
 * Reads a present direct addressing node arc, with the provided index in the label range and its corresponding
 * presence index (which is the count of presence bits before it).
 *
 * @param arc           The arc to fill; its {@code arcIdx}, {@code presenceIndex} and {@code flags} are updated.
 * @param in            The reader over the FST bytes; it is repositioned on the arc to read.
 * @param rangeIndex    The index of the arc in the label range.
 * @param presenceIndex The number of presence bits set before {@code rangeIndex}.
 * @return The result of {@code readArc(arc, in)}.
 */
private Arc<T> readArcByDirectAddressing(Arc<T> arc, final BytesReader in, int rangeIndex, int presenceIndex) throws IOException {
  // Widen to long before multiplying so the arc offset cannot overflow int arithmetic.
  in.setPosition(arc.posArcsStart() - presenceIndex * (long) arc.bytesPerArc());
  arc.arcIdx = rangeIndex;
  arc.presenceIndex = presenceIndex;
  arc.flags = in.readByte();
  return readArc(arc, in);
}
/**
 * Reads the last arc of a direct addressing node.
 * This method is equivalent to calling {@link #readArcByDirectAddressing(Arc, BytesReader, int)} with
 * {@code rangeIndex} equal to {@code arc.numArcs() - 1}, but it is faster.
 */
public Arc<T> readLastArcByDirectAddressing(Arc<T> arc, final BytesReader in) throws IOException {
  assert BitTable.assertIsValid(arc, in);
  // The last arc comes after all the other present arcs, so its presence index is (bits set - 1).
  final int presentArcs = BitTable.countBits(arc, in);
  final int lastRangeIndex = arc.numArcs() - 1;
  return readArcByDirectAddressing(arc, in, lastRangeIndex, presentArcs - 1);
}
/** Never returns null, but you should never call this if
* arc.isLast() is true. */
public Arc<T> readNextRealArc(Arc<T> arc, final BytesReader in) throws IOException {
@ -1227,11 +1228,10 @@ public final class FST<T> implements Accountable {
break;
case ARCS_FOR_DIRECT_ADDRESSING:
assert arc.bytesPerArc() > 0;
assert assertPresenceBytesAreValid(arc);
assert arc.arcIdx() == -1 || arc.bitTable().isBitSet(arc.arcIdx());
int nextIndex = arc.bitTable().nextBitSet(arc.arcIdx());
return readArcByDirectAddressing(arc, in, nextIndex);
assert BitTable.assertIsValid(arc, in);
assert arc.arcIdx() == -1 || BitTable.isBitSet(arc.arcIdx(), arc, in);
int nextIndex = BitTable.nextBitSet(arc.arcIdx(), arc, in);
return readArcByDirectAddressing(arc, in, nextIndex, arc.presenceIndex + 1);
default:
// Variable length arcs - linear search.
@ -1282,7 +1282,7 @@ public final class FST<T> implements Accountable {
// must scan
seekToNextNode(in);
} else {
int numArcs = arc.nodeFlags == ARCS_FOR_DIRECT_ADDRESSING ? getNumArcsDirectAddressing(arc) : arc.numArcs();
int numArcs = arc.nodeFlags == ARCS_FOR_DIRECT_ADDRESSING ? BitTable.countBits(arc, in) : arc.numArcs();
in.setPosition(arc.posArcsStart() - arc.bytesPerArc() * numArcs);
}
}
@ -1355,7 +1355,7 @@ public final class FST<T> implements Accountable {
int arcIndex = labelToMatch - arc.firstLabel();
if (arcIndex < 0 || arcIndex >= arc.numArcs()) {
return null; // Before or after label range.
} else if (!arc.bitTable().isBitSet(arcIndex)) {
} else if (!BitTable.isBitSet(arcIndex, arc, in)) {
return null; // Arc missing in the range.
}
return readArcByDirectAddressing(arc, in, arcIndex);
@ -1455,113 +1455,4 @@ public final class FST<T> implements Accountable {
* under-the-hood. */
public abstract boolean reversed();
}
/*
public void countSingleChains() throws IOException {
// TODO: must assert this FST was built with
// "willRewrite"
final List<ArcAndState<T>> queue = new ArrayList<>();
// TODO: use bitset to not revisit nodes already
// visited
FixedBitSet seen = new FixedBitSet(1+nodeCount);
int saved = 0;
queue.add(new ArcAndState<T>(getFirstArc(new Arc<T>()), new IntsRef()));
Arc<T> scratchArc = new Arc<>();
while(queue.size() > 0) {
//System.out.println("cycle size=" + queue.size());
//for(ArcAndState<T> ent : queue) {
// System.out.println(" " + Util.toBytesRef(ent.chain, new BytesRef()));
// }
final ArcAndState<T> arcAndState = queue.get(queue.size()-1);
seen.set(arcAndState.arc.node);
final BytesRef br = Util.toBytesRef(arcAndState.chain, new BytesRef());
if (br.length > 0 && br.bytes[br.length-1] == -1) {
br.length--;
}
//System.out.println(" top node=" + arcAndState.arc.target + " chain=" + br.utf8ToString());
if (targetHasArcs(arcAndState.arc) && !seen.get(arcAndState.arc.target)) {
// push
readFirstTargetArc(arcAndState.arc, scratchArc);
//System.out.println(" push label=" + (char) scratchArc.label);
//System.out.println(" tonode=" + scratchArc.target + " last?=" + scratchArc.isLast());
final IntsRef chain = IntsRef.deepCopyOf(arcAndState.chain);
chain.grow(1+chain.length);
// TODO
//assert scratchArc.label != END_LABEL;
chain.ints[chain.length] = scratchArc.label;
chain.length++;
if (scratchArc.isLast()) {
if (scratchArc.target != -1 && inCounts[scratchArc.target] == 1) {
//System.out.println(" append");
} else {
if (arcAndState.chain.length > 1) {
saved += chain.length-2;
try {
System.out.println("chain: " + Util.toBytesRef(chain, new BytesRef()).utf8ToString());
} catch (AssertionError ae) {
System.out.println("chain: " + Util.toBytesRef(chain, new BytesRef()));
}
}
chain.length = 0;
}
} else {
//System.out.println(" reset");
if (arcAndState.chain.length > 1) {
saved += arcAndState.chain.length-2;
try {
System.out.println("chain: " + Util.toBytesRef(arcAndState.chain, new BytesRef()).utf8ToString());
} catch (AssertionError ae) {
System.out.println("chain: " + Util.toBytesRef(arcAndState.chain, new BytesRef()));
}
}
if (scratchArc.target != -1 && inCounts[scratchArc.target] != 1) {
chain.length = 0;
} else {
chain.ints[0] = scratchArc.label;
chain.length = 1;
}
}
// TODO: instead of new Arc() we can re-use from
// a by-depth array
queue.add(new ArcAndState<T>(new Arc<T>().copyFrom(scratchArc), chain));
} else if (!arcAndState.arc.isLast()) {
// next
readNextArc(arcAndState.arc);
//System.out.println(" next label=" + (char) arcAndState.arc.label + " len=" + arcAndState.chain.length);
if (arcAndState.chain.length != 0) {
arcAndState.chain.ints[arcAndState.chain.length-1] = arcAndState.arc.label;
}
} else {
if (arcAndState.chain.length > 1) {
saved += arcAndState.chain.length-2;
System.out.println("chain: " + Util.toBytesRef(arcAndState.chain, new BytesRef()).utf8ToString());
}
// pop
//System.out.println(" pop");
queue.remove(queue.size()-1);
while(queue.size() > 0 && queue.get(queue.size()-1).arc.isLast()) {
queue.remove(queue.size()-1);
}
if (queue.size() > 0) {
final ArcAndState<T> arcAndState2 = queue.get(queue.size()-1);
readNextArc(arcAndState2.arc);
//System.out.println(" read next=" + (char) arcAndState2.arc.label + " queue=" + queue.size());
assert arcAndState2.arc.label != END_LABEL;
if (arcAndState2.chain.length != 0) {
arcAndState2.chain.ints[arcAndState2.chain.length-1] = arcAndState2.arc.label;
}
}
}
}
System.out.println("TOT saved " + saved);
}
*/
}

View File

@ -22,6 +22,8 @@ import java.io.IOException;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.RamUsageEstimator;
import static org.apache.lucene.util.fst.FST.Arc.BitTable;
/** Can next() and advance() through the terms in an FST
*
* @lucene.experimental
@ -36,7 +38,6 @@ abstract class FSTEnum<T> {
protected final T NO_OUTPUT;
protected final FST.BytesReader fstReader;
protected final FST.Arc<T> scratchArc = new FST.Arc<>();
protected int upto;
int targetLength;
@ -178,7 +179,7 @@ abstract class FSTEnum<T> {
} else {
if (targetIndex < 0) {
targetIndex = -1;
} else if (arc.bitTable().isBitSet(targetIndex)) {
} else if (BitTable.isBitSet(targetIndex, arc, in)) {
fst.readArcByDirectAddressing(arc, in, targetIndex);
assert arc.label() == targetLabel;
// found -- copy pasta from below
@ -191,7 +192,7 @@ abstract class FSTEnum<T> {
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
}
// Not found, return the next arc (ceil).
int ceilIndex = arc.bitTable().nextBitSet(targetIndex);
int ceilIndex = BitTable.nextBitSet(targetIndex, arc, in);
assert ceilIndex != -1;
fst.readArcByDirectAddressing(arc, in, ceilIndex);
assert arc.label() > targetLabel;
@ -335,14 +336,14 @@ abstract class FSTEnum<T> {
return backtrackToFloorArc(arc, targetLabel, in);
} else if (targetIndex >= arc.numArcs()) {
// After last arc.
fst.readArcByDirectAddressing(arc, in, arc.numArcs() - 1);
fst.readLastArcByDirectAddressing(arc, in);
assert arc.label() < targetLabel;
assert arc.isLast();
pushLast();
return null;
} else {
// Within label range.
if (arc.bitTable().isBitSet(targetIndex)) {
if (BitTable.isBitSet(targetIndex, arc, in)) {
fst.readArcByDirectAddressing(arc, in, targetIndex);
assert arc.label() == targetLabel;
// found -- copy pasta from below
@ -355,7 +356,7 @@ abstract class FSTEnum<T> {
return fst.readFirstTargetArc(arc, getArc(upto), fstReader);
}
// Scan backwards to find a floor arc.
int floorIndex = arc.bitTable().previousBitSet(targetIndex);
int floorIndex = BitTable.previousBitSet(targetIndex, arc, in);
assert floorIndex != -1;
fst.readArcByDirectAddressing(arc, in, floorIndex);
assert arc.label() < targetLabel;
@ -421,10 +422,10 @@ abstract class FSTEnum<T> {
assert targetIndex >= 0;
if (targetIndex >= arc.numArcs()) {
// Beyond last arc. Take last arc.
fst.readArcByDirectAddressing(arc, in, arc.numArcs() - 1);
fst.readLastArcByDirectAddressing(arc, in);
} else {
// Take the preceding arc, even if the target is present.
int floorIndex = arc.bitTable().previousBitSet(targetIndex);
int floorIndex = BitTable.previousBitSet(targetIndex, arc, in);
if (floorIndex > 0) {
fst.readArcByDirectAddressing(arc, in, floorIndex);
}

View File

@ -51,7 +51,7 @@ final class NodeHash<T> {
} else {
assert scratchArc.nodeFlags() == FST.ARCS_FOR_DIRECT_ADDRESSING;
if ((node.arcs[node.numArcs - 1].label - node.arcs[0].label + 1) != scratchArc.numArcs()
|| node.numArcs != scratchArc.bitTable().countBits()) {
|| node.numArcs != FST.Arc.BitTable.countBits(scratchArc, in)) {
return false;
}
}

View File

@ -33,6 +33,8 @@ import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import static org.apache.lucene.util.fst.FST.Arc.BitTable;
/** Static helper methods.
*
* @lucene.experimental */
@ -478,6 +480,7 @@ public final class Util {
// For each arc leaving this node:
boolean foundZero = false;
boolean arcCopyIsPending = false;
while(true) {
// tricky: instead of comparing output == 0, we must
// express it via the comparator compare(output, 0) == 0
@ -486,7 +489,7 @@ public final class Util {
foundZero = true;
break;
} else if (!foundZero) {
scratchArc.copyFrom(path.arc);
arcCopyIsPending = true;
foundZero = true;
} else {
addIfCompetitive(path);
@ -497,16 +500,16 @@ public final class Util {
if (path.arc.isLast()) {
break;
}
if (arcCopyIsPending) {
scratchArc.copyFrom(path.arc);
arcCopyIsPending = false;
}
fst.readNextArc(path.arc, fstReader);
}
assert foundZero;
if (queue != null) {
// TODO: maybe we can save this copyFrom if we
// are more clever above... eg on finding the
// first NO_OUTPUT arc we'd switch to using
// scratchArc
if (queue != null && !arcCopyIsPending) {
path.arc.copyFrom(scratchArc);
}
@ -948,11 +951,11 @@ public final class Util {
} else if (targetIndex < 0) {
return arc;
} else {
if (arc.bitTable().isBitSet(targetIndex)) {
if (BitTable.isBitSet(targetIndex, arc, in)) {
fst.readArcByDirectAddressing(arc, in, targetIndex);
assert arc.label() == label;
} else {
int ceilIndex = arc.bitTable().nextBitSet(targetIndex);
int ceilIndex = BitTable.nextBitSet(targetIndex, arc, in);
assert ceilIndex != -1;
fst.readArcByDirectAddressing(arc, in, ceilIndex);
assert arc.label() > label;

View File

@ -1,87 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
/**
 * Randomized tests that cross-check the {@code BitUtil} bit navigation methods
 * ({@code nextBitSet} and {@code previousBitSet}) against its bit counting methods
 * ({@code countBits}, {@code countBitsUpTo}) and {@code isBitSet}.
 */
public class TestBitUtil extends LuceneTestCase {

  /**
   * For every bit index, starting at -1, checks that {@code nextBitSet} either reports that no
   * further bit is set (then the count up to the index equals the total count), or returns the
   * index of a set bit that is exactly one more set bit away.
   */
  public void testNextBitSet() {
    final int iterations = atLeast(1000);
    for (int i = 0; i < iterations; i++) {
      final long[] bits = buildRandomBits();
      // The last long of the array is deliberately left out of the considered range.
      final int numLongs = bits.length - 1;
      final int numBits = Long.SIZE * numLongs;
      for (int bitIndex = -1; bitIndex < numBits; bitIndex++) {
        final String context = " (i=" + i + " bitIndex=" + bitIndex + ")";
        final int nextIndex = BitUtil.nextBitSet(bits, numLongs, bitIndex);
        if (nextIndex != -1) {
          assertTrue("Expected next bit set at nextIndex=" + nextIndex + context,
              BitUtil.isBitSet(bits, numLongs, nextIndex));
          assertEquals("Next bit set at nextIndex=" + nextIndex
                  + " so expected bit count diff of 1" + context,
              BitUtil.countBitsUpTo(bits, numLongs, bitIndex + 1) + 1,
              BitUtil.countBitsUpTo(bits, numLongs, nextIndex + 1));
        } else {
          assertEquals("No next bit set, so expected no bit count diff" + context,
              BitUtil.countBitsUpTo(bits, numLongs, bitIndex + 1),
              BitUtil.countBits(bits, numLongs));
        }
      }
    }
  }

  /**
   * For every bit index, up to and including the total number of considered bits, checks that
   * {@code previousBitSet} either reports that no earlier bit is set (then the count up to the
   * index is 0), or returns the index of a set bit with the expected number of set bits up to it.
   */
  public void testPreviousBitSet() {
    final int iterations = atLeast(1000);
    for (int i = 0; i < iterations; i++) {
      final long[] bits = buildRandomBits();
      // The last long of the array is deliberately left out of the considered range.
      final int numLongs = bits.length - 1;
      final int numBits = numLongs * Long.SIZE;
      for (int bitIndex = 0; bitIndex <= numBits; bitIndex++) {
        final String context = " (i=" + i + " bitIndex=" + bitIndex + ")";
        final int previousIndex = BitUtil.previousBitSet(bits, numLongs, bitIndex);
        if (previousIndex != -1) {
          assertTrue("Expected previous bit set at previousIndex=" + previousIndex + context,
              BitUtil.isBitSet(bits, numLongs, previousIndex));
          final int bitCount = BitUtil.countBitsUpTo(bits, numLongs, Math.min(bitIndex + 1, numBits));
          // The bit at bitIndex itself, when set, is counted in bitCount but is not "previous".
          final boolean currentBitIsSet = bitIndex < numBits && BitUtil.isBitSet(bits, numLongs, bitIndex);
          int expectedPreviousBitCount = bitCount;
          if (currentBitIsSet) {
            expectedPreviousBitCount = bitCount - 1;
          }
          assertEquals("Previous bit set at previousIndex=" + previousIndex
                  + " with current bitCount=" + bitCount
                  + " so expected previousBitCount=" + expectedPreviousBitCount
                  + context,
              expectedPreviousBitCount, BitUtil.countBitsUpTo(bits, numLongs, previousIndex + 1));
        } else {
          assertEquals("No previous bit set, so expected bit count 0" + context,
              0, BitUtil.countBitsUpTo(bits, numLongs, bitIndex));
        }
      }
    }
  }

  /** Builds an array of 2 to 4 random longs, roughly a quarter of which are forced to zero. */
  private long[] buildRandomBits() {
    final long[] bits = new long[random().nextInt(3) + 2];
    for (int j = 0; j < bits.length; j++) {
      // Bias towards zeros which require special logic.
      if (random().nextInt(4) == 0) {
        bits[j] = 0L;
      } else {
        bits[j] = random().nextLong();
      }
    }
    return bits;
  }
}

View File

@ -0,0 +1,138 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.fst;
import java.io.IOException;
import org.apache.lucene.util.LuceneTestCase;
/**
 * Randomized tests that cross-check the {@code BitTableUtil} bit navigation methods
 * ({@code nextBitSet} and {@code previousBitSet}) against its bit counting methods
 * ({@code countBits}, {@code countBitsUpTo}) and {@code isBitSet}, reading the bits through an
 * {@link FST.BytesReader} backed by a plain byte array.
 */
public class TestBitTableUtil extends LuceneTestCase {

  /**
   * For every bit index, starting at -1, checks that {@code nextBitSet} either reports that no
   * further bit is set (then the count up to the index equals the total count), or returns the
   * index of a set bit that is exactly one more set bit away.
   */
  public void testNextBitSet() throws IOException {
    final int iterations = atLeast(1000);
    for (int i = 0; i < iterations; i++) {
      final byte[] bits = buildRandomBits();
      // The last byte of the array is deliberately left out of the considered range.
      final int numBytes = bits.length - 1;
      final int numBits = Byte.SIZE * numBytes;
      for (int bitIndex = -1; bitIndex < numBits; bitIndex++) {
        final String context = " (i=" + i + " bitIndex=" + bitIndex + ")";
        final int nextIndex = BitTableUtil.nextBitSet(bitIndex, numBytes, reader(bits));
        if (nextIndex != -1) {
          assertTrue("Expected next bit set at nextIndex=" + nextIndex + context,
              BitTableUtil.isBitSet(nextIndex, reader(bits)));
          assertEquals("Next bit set at nextIndex=" + nextIndex
                  + " so expected bit count diff of 1" + context,
              BitTableUtil.countBitsUpTo(bitIndex + 1, reader(bits)) + 1,
              BitTableUtil.countBitsUpTo(nextIndex + 1, reader(bits)));
        } else {
          assertEquals("No next bit set, so expected no bit count diff" + context,
              BitTableUtil.countBitsUpTo(bitIndex + 1, reader(bits)),
              BitTableUtil.countBits(numBytes, reader(bits)));
        }
      }
    }
  }

  /**
   * For every bit index, up to and including the total number of considered bits, checks that
   * {@code previousBitSet} either reports that no earlier bit is set (then the count up to the
   * index is 0), or returns the index of a set bit with the expected number of set bits up to it.
   */
  public void testPreviousBitSet() throws IOException {
    final int iterations = atLeast(1000);
    for (int i = 0; i < iterations; i++) {
      final byte[] bits = buildRandomBits();
      // The last byte of the array is deliberately left out of the considered range.
      final int numBytes = bits.length - 1;
      final int numBits = Byte.SIZE * numBytes;
      for (int bitIndex = 0; bitIndex <= numBits; bitIndex++) {
        final String context = " (i=" + i + " bitIndex=" + bitIndex + ")";
        final int previousIndex = BitTableUtil.previousBitSet(bitIndex, reader(bits));
        if (previousIndex != -1) {
          assertTrue("Expected previous bit set at previousIndex=" + previousIndex + context,
              BitTableUtil.isBitSet(previousIndex, reader(bits)));
          final int bitCount = BitTableUtil.countBitsUpTo(Math.min(bitIndex + 1, numBits), reader(bits));
          // The bit at bitIndex itself, when set, is counted in bitCount but is not "previous".
          final boolean currentBitIsSet = bitIndex < numBits && BitTableUtil.isBitSet(bitIndex, reader(bits));
          int expectedPreviousBitCount = bitCount;
          if (currentBitIsSet) {
            expectedPreviousBitCount = bitCount - 1;
          }
          assertEquals("Previous bit set at previousIndex=" + previousIndex
                  + " with current bitCount=" + bitCount
                  + " so expected previousBitCount=" + expectedPreviousBitCount
                  + context,
              expectedPreviousBitCount, BitTableUtil.countBitsUpTo(previousIndex + 1, reader(bits)));
        } else {
          assertEquals("No previous bit set, so expected bit count 0" + context,
              0, BitTableUtil.countBitsUpTo(bitIndex, reader(bits)));
        }
      }
    }
  }

  /** Builds an array of 2 to 25 random bytes, roughly a quarter of which are forced to zero. */
  private byte[] buildRandomBits() {
    final byte[] bits = new byte[random().nextInt(24) + 2];
    for (int i = 0; i < bits.length; i++) {
      // Bias towards zeros which require special logic.
      if (random().nextInt(4) == 0) {
        bits[i] = 0;
      } else {
        bits[i] = (byte) random().nextInt();
      }
    }
    return bits;
  }

  /** Creates a fresh reader, positioned at 0, over the given bytes. */
  private static FST.BytesReader reader(byte[] bits) {
    return new ByteArrayBytesReader(bits);
  }

  /**
   * Minimal forward (non-reversed) {@link FST.BytesReader} over a byte array. Only single-byte
   * reads, skipping and positioning are supported; bulk reads are not.
   */
  private static class ByteArrayBytesReader extends FST.BytesReader {

    private final byte[] bits;
    private int position;

    ByteArrayBytesReader(byte[] bits) {
      this.bits = bits;
    }

    @Override
    public long getPosition() {
      return position;
    }

    @Override
    public void setPosition(long pos) {
      this.position = (int) pos;
    }

    @Override
    public boolean reversed() {
      return false;
    }

    @Override
    public byte readByte() {
      final byte value = bits[position];
      position++;
      return value;
    }

    @Override
    public void readBytes(byte[] b, int offset, int len) {
      throw new UnsupportedOperationException();
    }

    @Override
    public void skipBytes(long numBytes) {
      position = (int) (position + numBytes);
    }
  }
}

View File

@ -18,8 +18,10 @@ package org.apache.lucene.util.fst;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
@ -28,10 +30,13 @@ import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.zip.GZIPInputStream;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.InputStreamDataInput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.LuceneTestCase;
@ -151,17 +156,23 @@ public class TestFSTDirectAddressing extends LuceneTestCase {
if (args.length < 2) {
throw new IllegalArgumentException("Missing argument");
}
if (args[0].equals("-countFSTArcs")) {
countFSTArcs(args[1]);
} else if (args[0].equals("-measureFSTOversizing")) {
measureFSTOversizing(args[1]);
} else {
throw new IllegalArgumentException("Invalid argument " + args[0]);
switch (args[0]) {
case "-countFSTArcs":
countFSTArcs(args[1]);
break;
case "-measureFSTOversizing":
measureFSTOversizing(args[1]);
break;
case "-recompileAndWalk":
recompileAndWalk(args[1]);
break;
default:
throw new IllegalArgumentException("Invalid argument " + args[0]);
}
}
private static void countFSTArcs(String FSTFilePath) throws IOException {
byte[] buf = Files.readAllBytes(Paths.get(FSTFilePath));
private static void countFSTArcs(String fstFilePath) throws IOException {
byte[] buf = Files.readAllBytes(Paths.get(fstFilePath));
DataInput in = new ByteArrayDataInput(buf);
FST<BytesRef> fst = new FST<>(in, ByteSequenceOutputs.getSingleton());
BytesRefFSTEnum<BytesRef> fstEnum = new BytesRefFSTEnum<>(fst);
@ -211,4 +222,62 @@ public class TestFSTDirectAddressing extends LuceneTestCase {
printStats(fstCompiler, ramBytesUsed, directAddressingMemoryIncreasePercent);
}
/**
 * Loads the FST stored in the file at {@code fstFilePath}, then, for each oversizing factor of a
 * fixed list (0 three times, then 1 three times), recompiles the FST with that factor and
 * enumerates it, printing the elapsed time of each step and the RAM used by each recompiled FST.
 *
 * <p>Elapsed times are measured with {@link System#nanoTime()}, which is monotonic, rather than
 * with the wall clock, so they are not distorted by system clock adjustments.
 *
 * @param fstFilePath path of the file to read the FST from.
 * @throws IOException if reading the file, recompiling or enumerating the FST fails.
 */
private static void recompileAndWalk(String fstFilePath) throws IOException {
  final long nanosPerMilli = 1_000_000L;
  try (InputStreamDataInput in = new InputStreamDataInput(newInputStream(Paths.get(fstFilePath)))) {

    System.out.println("Reading FST");
    long startTimeNs = System.nanoTime();
    FST<CharsRef> originalFst = new FST<>(in, CharSequenceOutputs.getSingleton());
    long elapsedMs = (System.nanoTime() - startTimeNs) / nanosPerMilli;
    System.out.println("time = " + elapsedMs + " ms");

    // Each factor is run several times in a row so that later runs are less affected by warm-up.
    for (float oversizingFactor : List.of(0f, 0f, 0f, 1f, 1f, 1f)) {
      System.out.println("\nFST construction (oversizingFactor=" + oversizingFactor + ")");
      startTimeNs = System.nanoTime();
      FST<CharsRef> fst = recompile(originalFst, oversizingFactor);
      elapsedMs = (System.nanoTime() - startTimeNs) / nanosPerMilli;
      System.out.println("time = " + elapsedMs + " ms");
      System.out.println("FST RAM = " + fst.ramBytesUsed() + " B");

      System.out.println("FST enum");
      startTimeNs = System.nanoTime();
      walk(fst);
      elapsedMs = (System.nanoTime() - startTimeNs) / nanosPerMilli;
      System.out.println("time = " + elapsedMs + " ms");
    }
  }
}
/**
 * Opens the file at {@code path} for reading, decompressing it on the fly depending on the
 * suffix of its file name:
 * <ul>
 *   <li>a name ending with "gz" is read as GZIP data;</li>
 *   <li>a name ending with "zip" is read as a ZIP archive, and the returned stream is positioned
 *       on the content of its first non-directory entry;</li>
 *   <li>any other file is returned as is.</li>
 * </ul>
 *
 * <p>If wrapping the file stream fails (for example because the content does not match the
 * suffix), the underlying file stream is closed before the exception is propagated.
 *
 * @param path the file to open.
 * @return the stream to read the (decompressed) file content from; the caller must close it.
 * @throws IOException if the file cannot be opened, is not in the expected compressed format,
 *     or is a ZIP archive without any file entry.
 */
private static InputStream newInputStream(Path path) throws IOException {
  InputStream raw = Files.newInputStream(path);
  try {
    String fileName = path.getFileName().toString();
    if (fileName.endsWith("gz")) {
      return new GZIPInputStream(raw);
    }
    if (fileName.endsWith("zip")) {
      // A ZIP archive is not GZIP data: it must be read entry by entry.
      java.util.zip.ZipInputStream zip = new java.util.zip.ZipInputStream(raw);
      java.util.zip.ZipEntry entry;
      do {
        entry = zip.getNextEntry();
      } while (entry != null && entry.isDirectory());
      if (entry == null) {
        throw new IOException("No file entry found in zip archive " + path);
      }
      return zip;
    }
    return raw;
  } catch (IOException | RuntimeException e) {
    // Do not leak the file handle when the decompressing wrapper cannot be set up.
    try {
      raw.close();
    } catch (IOException closeFailure) {
      e.addSuppressed(closeFailure);
    }
    throw e;
  }
}
/**
 * Builds a new FST holding the same input/output pairs as {@code fst}, compiled with
 * {@link FST.INPUT_TYPE#BYTE4} inputs and the given direct addressing max oversizing factor.
 *
 * @param fst the FST whose entries are enumerated and re-added.
 * @param oversizingFactor value passed to {@code directAddressingMaxOversizingFactor} of the builder.
 * @return the newly compiled FST.
 */
private static FST<CharsRef> recompile(FST<CharsRef> fst, float oversizingFactor) throws IOException {
  FSTCompiler<CharsRef> compiler =
      new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE4, CharSequenceOutputs.getSingleton())
          .directAddressingMaxOversizingFactor(oversizingFactor)
          .build();
  IntsRefFSTEnum<CharsRef> entries = new IntsRefFSTEnum<>(fst);
  for (IntsRefFSTEnum.InputOutput<CharsRef> entry = entries.next();
       entry != null;
       entry = entries.next()) {
    // The output is deep-copied before being handed to the compiler.
    fstCompiler_add:
    compiler.add(entry.input, CharsRef.deepCopyOf(entry.output));
  }
  return compiler.compile();
}
/**
 * Enumerates every entry of the given FST.
 *
 * @param read the FST to enumerate.
 * @return the sum, over all entries, of the input length and the output length.
 */
private static int walk(FST<CharsRef> read) throws IOException {
  IntsRefFSTEnum<CharsRef> entries = new IntsRefFSTEnum<>(read);
  int totalLength = 0;
  for (IntsRefFSTEnum.InputOutput<CharsRef> entry = entries.next();
       entry != null;
       entry = entries.next()) {
    totalLength += entry.input.length;
    totalLength += entry.output.length;
  }
  return totalLength;
}
}