From 04f38dd288535eb2ffe12a85f7064b992d07d5b7 Mon Sep 17 00:00:00 2001 From: Dzung Bui Date: Tue, 10 Oct 2023 21:34:57 +0900 Subject: [PATCH] Move addNode to FSTCompiler (#12646) Currently FSTCompiler and FST have circular dependencies on each other. FSTCompiler creates an instance of FST, and on adding a node (add(IntsRef input, T output)), it delegates to FST.addNode() and passes itself as a variable. This introduces a circular dependency and mixes up the FST constructing and traversing code. To make matters worse, this implies one can call FST.addNode with an arbitrary FSTCompiler (as it's a parameter), but in reality it should be the compiler which creates the FST. This commit moves the addNode method to FSTCompiler instead. Co-authored-by: Anh Dung Bui --- lucene/CHANGES.txt | 1 + .../java/org/apache/lucene/util/fst/FST.java | 420 +----------------- .../apache/lucene/util/fst/FSTCompiler.java | 405 ++++++++++++++++- .../org/apache/lucene/util/fst/NodeHash.java | 2 +- .../org/apache/lucene/util/fst/TestFSTs.java | 17 +- 5 files changed, 420 insertions(+), 425 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 2fdc2af5378..51c3b3a3e79 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -143,6 +143,7 @@ API Changes * GITHUB#12592: Add RandomAccessInput#length method to the RandomAccessInput interface. In addition deprecate ByteBuffersDataInput#size in favour of this new method. 
(Ignacio Vera) +* GITHUB#12646: Move FST#addNode to FSTCompiler to avoid a circular dependency between FST and FSTCompiler New Features --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FST.java b/lucene/core/src/java/org/apache/lucene/util/fst/FST.java index fb356c2c9c7..a3b85cf5e31 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/FST.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/FST.java @@ -33,7 +33,6 @@ import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.InputStreamDataInput; import org.apache.lucene.store.OutputStreamDataOutput; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Constants; import org.apache.lucene.util.RamUsageEstimator; @@ -72,17 +71,17 @@ public final class FST implements Accountable { private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FST.class); - private static final int BIT_FINAL_ARC = 1 << 0; + static final int BIT_FINAL_ARC = 1 << 0; static final int BIT_LAST_ARC = 1 << 1; static final int BIT_TARGET_NEXT = 1 << 2; // TODO: we can free up a bit if we can nuke this: - private static final int BIT_STOP_NODE = 1 << 3; + static final int BIT_STOP_NODE = 1 << 3; /** This flag is set if the arc has an output. */ public static final int BIT_ARC_HAS_OUTPUT = 1 << 4; - private static final int BIT_ARC_HAS_FINAL_OUTPUT = 1 << 5; + static final int BIT_ARC_HAS_FINAL_OUTPUT = 1 << 5; /** Value of the arc flags to declare a node with fixed length arcs designed for binary search. */ // We use this as a marker because this one flag is illegal by itself. @@ -94,30 +93,6 @@ public final class FST implements Accountable { */ static final byte ARCS_FOR_DIRECT_ADDRESSING = 1 << 6; - /** - * @see #shouldExpandNodeWithFixedLengthArcs - */ - static final int FIXED_LENGTH_ARC_SHALLOW_DEPTH = 3; // 0 => only root node. 
- - /** - * @see #shouldExpandNodeWithFixedLengthArcs - */ - static final int FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS = 5; - - /** - * @see #shouldExpandNodeWithFixedLengthArcs - */ - static final int FIXED_LENGTH_ARC_DEEP_NUM_ARCS = 10; - - /** - * Maximum oversizing factor allowed for direct addressing compared to binary search when - * expansion credits allow the oversizing. This factor prevents expansions that are obviously too - * costly even if there are sufficient credits. - * - * @see #shouldExpandNodeWithDirectAddressing - */ - private static final float DIRECT_ADDRESSING_MAX_OVERSIZE_WITH_CREDIT_FACTOR = 1.66f; - // Increment version to change it private static final String FILE_FORMAT_NAME = "FST"; private static final int VERSION_START = 6; @@ -126,11 +101,11 @@ public final class FST implements Accountable { // Never serialized; just used to represent the virtual // final node w/ no arcs: - private static final long FINAL_END_NODE = -1; + static final long FINAL_END_NODE = -1; // Never serialized; just used to represent the virtual // non-final node w/ no arcs: - private static final long NON_FINAL_END_NODE = 0; + static final long NON_FINAL_END_NODE = 0; /** If arc has this label then that arc is final/accepted */ public static final int END_LABEL = -1; @@ -603,19 +578,6 @@ public final class FST implements Accountable { } } - private void writeLabel(DataOutput out, int v) throws IOException { - assert v >= 0 : "v=" + v; - if (inputType == INPUT_TYPE.BYTE1) { - assert v <= 255 : "v=" + v; - out.writeByte((byte) v); - } else if (inputType == INPUT_TYPE.BYTE2) { - assert v <= 65535 : "v=" + v; - out.writeShort((short) v); - } else { - out.writeVInt(v); - } - } - /** Reads one BYTE1/2/4 label from the provided {@link DataInput}. 
*/ public int readLabel(DataInput in) throws IOException { final int v; @@ -640,381 +602,11 @@ public final class FST implements Accountable { return arc.target() > 0; } - // serializes new node by appending its bytes to the end - // of the current byte[] - long addNode(FSTCompiler fstCompiler, FSTCompiler.UnCompiledNode nodeIn) - throws IOException { - T NO_OUTPUT = outputs.getNoOutput(); - - // System.out.println("FST.addNode pos=" + bytes.getPosition() + " numArcs=" + nodeIn.numArcs); - if (nodeIn.numArcs == 0) { - if (nodeIn.isFinal) { - return FINAL_END_NODE; - } else { - return NON_FINAL_END_NODE; - } - } - final long startAddress = fstCompiler.bytes.getPosition(); - // System.out.println(" startAddr=" + startAddress); - - final boolean doFixedLengthArcs = shouldExpandNodeWithFixedLengthArcs(fstCompiler, nodeIn); - if (doFixedLengthArcs) { - // System.out.println(" fixed length arcs"); - if (fstCompiler.numBytesPerArc.length < nodeIn.numArcs) { - fstCompiler.numBytesPerArc = new int[ArrayUtil.oversize(nodeIn.numArcs, Integer.BYTES)]; - fstCompiler.numLabelBytesPerArc = new int[fstCompiler.numBytesPerArc.length]; - } - } - - fstCompiler.arcCount += nodeIn.numArcs; - - final int lastArc = nodeIn.numArcs - 1; - - long lastArcStart = fstCompiler.bytes.getPosition(); - int maxBytesPerArc = 0; - int maxBytesPerArcWithoutLabel = 0; - for (int arcIdx = 0; arcIdx < nodeIn.numArcs; arcIdx++) { - final FSTCompiler.Arc arc = nodeIn.arcs[arcIdx]; - final FSTCompiler.CompiledNode target = (FSTCompiler.CompiledNode) arc.target; - int flags = 0; - // System.out.println(" arc " + arcIdx + " label=" + arc.label + " -> target=" + - // target.node); - - if (arcIdx == lastArc) { - flags += BIT_LAST_ARC; - } - - if (fstCompiler.lastFrozenNode == target.node && !doFixedLengthArcs) { - // TODO: for better perf (but more RAM used) we - // could avoid this except when arc is "near" the - // last arc: - flags += BIT_TARGET_NEXT; - } - - if (arc.isFinal) { - flags += BIT_FINAL_ARC; - if 
(arc.nextFinalOutput != NO_OUTPUT) { - flags += BIT_ARC_HAS_FINAL_OUTPUT; - } - } else { - assert arc.nextFinalOutput == NO_OUTPUT; - } - - boolean targetHasArcs = target.node > 0; - - if (!targetHasArcs) { - flags += BIT_STOP_NODE; - } - - if (arc.output != NO_OUTPUT) { - flags += BIT_ARC_HAS_OUTPUT; - } - - fstCompiler.bytes.writeByte((byte) flags); - long labelStart = fstCompiler.bytes.getPosition(); - writeLabel(fstCompiler.bytes, arc.label); - int numLabelBytes = (int) (fstCompiler.bytes.getPosition() - labelStart); - - // System.out.println(" write arc: label=" + (char) arc.label + " flags=" + flags + " - // target=" + target.node + " pos=" + bytes.getPosition() + " output=" + - // outputs.outputToString(arc.output)); - - if (arc.output != NO_OUTPUT) { - outputs.write(arc.output, fstCompiler.bytes); - // System.out.println(" write output"); - } - - if (arc.nextFinalOutput != NO_OUTPUT) { - // System.out.println(" write final output"); - outputs.writeFinalOutput(arc.nextFinalOutput, fstCompiler.bytes); - } - - if (targetHasArcs && (flags & BIT_TARGET_NEXT) == 0) { - assert target.node > 0; - // System.out.println(" write target"); - fstCompiler.bytes.writeVLong(target.node); - } - - // just write the arcs "like normal" on first pass, but record how many bytes each one took - // and max byte size: - if (doFixedLengthArcs) { - int numArcBytes = (int) (fstCompiler.bytes.getPosition() - lastArcStart); - fstCompiler.numBytesPerArc[arcIdx] = numArcBytes; - fstCompiler.numLabelBytesPerArc[arcIdx] = numLabelBytes; - lastArcStart = fstCompiler.bytes.getPosition(); - maxBytesPerArc = Math.max(maxBytesPerArc, numArcBytes); - maxBytesPerArcWithoutLabel = - Math.max(maxBytesPerArcWithoutLabel, numArcBytes - numLabelBytes); - // System.out.println(" arcBytes=" + numArcBytes + " labelBytes=" + numLabelBytes); - } - } - - // TODO: try to avoid wasteful cases: disable doFixedLengthArcs in that case - /* - * - * LUCENE-4682: what is a fair heuristic here? 
- * It could involve some of these: - * 1. how "busy" the node is: nodeIn.inputCount relative to frontier[0].inputCount? - * 2. how much binSearch saves over scan: nodeIn.numArcs - * 3. waste: numBytes vs numBytesExpanded - * - * the one below just looks at #3 - if (doFixedLengthArcs) { - // rough heuristic: make this 1.25 "waste factor" a parameter to the phd ctor???? - int numBytes = lastArcStart - startAddress; - int numBytesExpanded = maxBytesPerArc * nodeIn.numArcs; - if (numBytesExpanded > numBytes*1.25) { - doFixedLengthArcs = false; - } - } - */ - - if (doFixedLengthArcs) { - assert maxBytesPerArc > 0; - // 2nd pass just "expands" all arcs to take up a fixed byte size - - int labelRange = nodeIn.arcs[nodeIn.numArcs - 1].label - nodeIn.arcs[0].label + 1; - assert labelRange > 0; - if (shouldExpandNodeWithDirectAddressing( - fstCompiler, nodeIn, maxBytesPerArc, maxBytesPerArcWithoutLabel, labelRange)) { - writeNodeForDirectAddressing( - fstCompiler, nodeIn, startAddress, maxBytesPerArcWithoutLabel, labelRange); - fstCompiler.directAddressingNodeCount++; - } else { - writeNodeForBinarySearch(fstCompiler, nodeIn, startAddress, maxBytesPerArc); - fstCompiler.binarySearchNodeCount++; - } - } - - final long thisNodeAddress = fstCompiler.bytes.getPosition() - 1; - fstCompiler.bytes.reverse(startAddress, thisNodeAddress); - fstCompiler.nodeCount++; - return thisNodeAddress; - } - - /** - * Returns whether the given node should be expanded with fixed length arcs. Nodes will be - * expanded depending on their depth (distance from the root node) and their number of arcs. - * - *
<p>
Nodes with fixed length arcs use more space, because they encode all arcs with a fixed - * number of bytes, but they allow either binary search or direct addressing on the arcs (instead - * of linear scan) on lookup by arc label. - */ - private boolean shouldExpandNodeWithFixedLengthArcs( - FSTCompiler fstCompiler, FSTCompiler.UnCompiledNode node) { - return fstCompiler.allowFixedLengthArcs - && ((node.depth <= FIXED_LENGTH_ARC_SHALLOW_DEPTH - && node.numArcs >= FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS) - || node.numArcs >= FIXED_LENGTH_ARC_DEEP_NUM_ARCS); - } - - /** - * Returns whether the given node should be expanded with direct addressing instead of binary - * search. - * - *
<p>
Prefer direct addressing for performance if it does not oversize binary search byte size too - * much, so that the arcs can be directly addressed by label. - * - * @see FSTCompiler#getDirectAddressingMaxOversizingFactor() - */ - private boolean shouldExpandNodeWithDirectAddressing( - FSTCompiler fstCompiler, - FSTCompiler.UnCompiledNode nodeIn, - int numBytesPerArc, - int maxBytesPerArcWithoutLabel, - int labelRange) { - // Anticipate precisely the size of the encodings. - int sizeForBinarySearch = numBytesPerArc * nodeIn.numArcs; - int sizeForDirectAddressing = - getNumPresenceBytes(labelRange) - + fstCompiler.numLabelBytesPerArc[0] - + maxBytesPerArcWithoutLabel * nodeIn.numArcs; - - // Determine the allowed oversize compared to binary search. - // This is defined by a parameter of FST Builder (default 1: no oversize). - int allowedOversize = - (int) (sizeForBinarySearch * fstCompiler.getDirectAddressingMaxOversizingFactor()); - int expansionCost = sizeForDirectAddressing - allowedOversize; - - // Select direct addressing if either: - // - Direct addressing size is smaller than binary search. - // In this case, increment the credit by the reduced size (to use it later). - // - Direct addressing size is larger than binary search, but the positive credit allows the - // oversizing. - // In this case, decrement the credit by the oversize. - // In addition, do not try to oversize to a clearly too large node size - // (this is the DIRECT_ADDRESSING_MAX_OVERSIZE_WITH_CREDIT_FACTOR parameter). 
- if (expansionCost <= 0 - || (fstCompiler.directAddressingExpansionCredit >= expansionCost - && sizeForDirectAddressing - <= allowedOversize * DIRECT_ADDRESSING_MAX_OVERSIZE_WITH_CREDIT_FACTOR)) { - fstCompiler.directAddressingExpansionCredit -= expansionCost; - return true; - } - return false; - } - - private void writeNodeForBinarySearch( - FSTCompiler fstCompiler, - FSTCompiler.UnCompiledNode nodeIn, - long startAddress, - int maxBytesPerArc) { - // Build the header in a buffer. - // It is a false/special arc which is in fact a node header with node flags followed by node - // metadata. - fstCompiler - .fixedLengthArcsBuffer - .resetPosition() - .writeByte(ARCS_FOR_BINARY_SEARCH) - .writeVInt(nodeIn.numArcs) - .writeVInt(maxBytesPerArc); - int headerLen = fstCompiler.fixedLengthArcsBuffer.getPosition(); - - // Expand the arcs in place, backwards. - long srcPos = fstCompiler.bytes.getPosition(); - long destPos = startAddress + headerLen + nodeIn.numArcs * (long) maxBytesPerArc; - assert destPos >= srcPos; - if (destPos > srcPos) { - fstCompiler.bytes.skipBytes((int) (destPos - srcPos)); - for (int arcIdx = nodeIn.numArcs - 1; arcIdx >= 0; arcIdx--) { - destPos -= maxBytesPerArc; - int arcLen = fstCompiler.numBytesPerArc[arcIdx]; - srcPos -= arcLen; - if (srcPos != destPos) { - assert destPos > srcPos - : "destPos=" - + destPos - + " srcPos=" - + srcPos - + " arcIdx=" - + arcIdx - + " maxBytesPerArc=" - + maxBytesPerArc - + " arcLen=" - + arcLen - + " nodeIn.numArcs=" - + nodeIn.numArcs; - fstCompiler.bytes.copyBytes(srcPos, destPos, arcLen); - } - } - } - - // Write the header. - fstCompiler.bytes.writeBytes( - startAddress, fstCompiler.fixedLengthArcsBuffer.getBytes(), 0, headerLen); - } - - private void writeNodeForDirectAddressing( - FSTCompiler fstCompiler, - FSTCompiler.UnCompiledNode nodeIn, - long startAddress, - int maxBytesPerArcWithoutLabel, - int labelRange) { - // Expand the arcs backwards in a buffer because we remove the labels. 
- // So the obtained arcs might occupy less space. This is the reason why this - // whole method is more complex. - // Drop the label bytes since we can infer the label based on the arc index, - // the presence bits, and the first label. Keep the first label. - int headerMaxLen = 11; - int numPresenceBytes = getNumPresenceBytes(labelRange); - long srcPos = fstCompiler.bytes.getPosition(); - int totalArcBytes = - fstCompiler.numLabelBytesPerArc[0] + nodeIn.numArcs * maxBytesPerArcWithoutLabel; - int bufferOffset = headerMaxLen + numPresenceBytes + totalArcBytes; - byte[] buffer = fstCompiler.fixedLengthArcsBuffer.ensureCapacity(bufferOffset).getBytes(); - // Copy the arcs to the buffer, dropping all labels except first one. - for (int arcIdx = nodeIn.numArcs - 1; arcIdx >= 0; arcIdx--) { - bufferOffset -= maxBytesPerArcWithoutLabel; - int srcArcLen = fstCompiler.numBytesPerArc[arcIdx]; - srcPos -= srcArcLen; - int labelLen = fstCompiler.numLabelBytesPerArc[arcIdx]; - // Copy the flags. - fstCompiler.bytes.copyBytes(srcPos, buffer, bufferOffset, 1); - // Skip the label, copy the remaining. - int remainingArcLen = srcArcLen - 1 - labelLen; - if (remainingArcLen != 0) { - fstCompiler.bytes.copyBytes( - srcPos + 1 + labelLen, buffer, bufferOffset + 1, remainingArcLen); - } - if (arcIdx == 0) { - // Copy the label of the first arc only. - bufferOffset -= labelLen; - fstCompiler.bytes.copyBytes(srcPos + 1, buffer, bufferOffset, labelLen); - } - } - assert bufferOffset == headerMaxLen + numPresenceBytes; - - // Build the header in the buffer. - // It is a false/special arc which is in fact a node header with node flags followed by node - // metadata. - fstCompiler - .fixedLengthArcsBuffer - .resetPosition() - .writeByte(ARCS_FOR_DIRECT_ADDRESSING) - .writeVInt(labelRange) // labelRange instead of numArcs. - .writeVInt( - maxBytesPerArcWithoutLabel); // maxBytesPerArcWithoutLabel instead of maxBytesPerArc. 
- int headerLen = fstCompiler.fixedLengthArcsBuffer.getPosition(); - - // Prepare the builder byte store. Enlarge or truncate if needed. - long nodeEnd = startAddress + headerLen + numPresenceBytes + totalArcBytes; - long currentPosition = fstCompiler.bytes.getPosition(); - if (nodeEnd >= currentPosition) { - fstCompiler.bytes.skipBytes((int) (nodeEnd - currentPosition)); - } else { - fstCompiler.bytes.truncate(nodeEnd); - } - assert fstCompiler.bytes.getPosition() == nodeEnd; - - // Write the header. - long writeOffset = startAddress; - fstCompiler.bytes.writeBytes( - writeOffset, fstCompiler.fixedLengthArcsBuffer.getBytes(), 0, headerLen); - writeOffset += headerLen; - - // Write the presence bits - writePresenceBits(fstCompiler, nodeIn, writeOffset, numPresenceBytes); - writeOffset += numPresenceBytes; - - // Write the first label and the arcs. - fstCompiler.bytes.writeBytes( - writeOffset, fstCompiler.fixedLengthArcsBuffer.getBytes(), bufferOffset, totalArcBytes); - } - - private void writePresenceBits( - FSTCompiler fstCompiler, - FSTCompiler.UnCompiledNode nodeIn, - long dest, - int numPresenceBytes) { - long bytePos = dest; - byte presenceBits = 1; // The first arc is always present. - int presenceIndex = 0; - int previousLabel = nodeIn.arcs[0].label; - for (int arcIdx = 1; arcIdx < nodeIn.numArcs; arcIdx++) { - int label = nodeIn.arcs[arcIdx].label; - assert label > previousLabel; - presenceIndex += label - previousLabel; - while (presenceIndex >= Byte.SIZE) { - fstCompiler.bytes.writeByte(bytePos++, presenceBits); - presenceBits = 0; - presenceIndex -= Byte.SIZE; - } - // Set the bit at presenceIndex to flag that the corresponding arc is present. - presenceBits |= 1 << presenceIndex; - previousLabel = label; - } - assert presenceIndex == (nodeIn.arcs[nodeIn.numArcs - 1].label - nodeIn.arcs[0].label) % 8; - assert presenceBits != 0; // The last byte is not 0. - assert (presenceBits & (1 << presenceIndex)) != 0; // The last arc is always present. 
- fstCompiler.bytes.writeByte(bytePos++, presenceBits); - assert bytePos - dest == numPresenceBytes; - } - /** * Gets the number of bytes required to flag the presence of each arc in the given label range, * one bit per arc. */ - private static int getNumPresenceBytes(int labelRange) { + static int getNumPresenceBytes(int labelRange) { assert labelRange >= 0; return (labelRange + 7) >> 3; } diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java b/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java index 550a57b3ec1..c968fa68db2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java @@ -16,8 +16,21 @@ */ package org.apache.lucene.util.fst; +import static org.apache.lucene.util.fst.FST.ARCS_FOR_BINARY_SEARCH; +import static org.apache.lucene.util.fst.FST.ARCS_FOR_DIRECT_ADDRESSING; +import static org.apache.lucene.util.fst.FST.BIT_ARC_HAS_FINAL_OUTPUT; +import static org.apache.lucene.util.fst.FST.BIT_ARC_HAS_OUTPUT; +import static org.apache.lucene.util.fst.FST.BIT_FINAL_ARC; +import static org.apache.lucene.util.fst.FST.BIT_LAST_ARC; +import static org.apache.lucene.util.fst.FST.BIT_STOP_NODE; +import static org.apache.lucene.util.fst.FST.BIT_TARGET_NEXT; +import static org.apache.lucene.util.fst.FST.FINAL_END_NODE; +import static org.apache.lucene.util.fst.FST.NON_FINAL_END_NODE; +import static org.apache.lucene.util.fst.FST.getNumPresenceBytes; + import java.io.IOException; import org.apache.lucene.store.ByteArrayDataOutput; +import org.apache.lucene.store.DataOutput; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.IntsRef; import org.apache.lucene.util.IntsRefBuilder; @@ -46,6 +59,30 @@ public class FSTCompiler { static final float DIRECT_ADDRESSING_MAX_OVERSIZING_FACTOR = 1f; + /** + * @see #shouldExpandNodeWithFixedLengthArcs + */ + static final int FIXED_LENGTH_ARC_SHALLOW_DEPTH = 3; // 0 => only root node. 
+ + /** + * @see #shouldExpandNodeWithFixedLengthArcs + */ + static final int FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS = 5; + + /** + * @see #shouldExpandNodeWithFixedLengthArcs + */ + static final int FIXED_LENGTH_ARC_DEEP_NUM_ARCS = 10; + + /** + * Maximum oversizing factor allowed for direct addressing compared to binary search when + * expansion credits allow the oversizing. This factor prevents expansions that are obviously too + * costly even if there are sufficient credits. + * + * @see #shouldExpandNodeWithDirectAddressing + */ + private static final float DIRECT_ADDRESSING_MAX_OVERSIZE_WITH_CREDIT_FACTOR = 1.66f; + private final NodeHash dedupHash; final FST fst; private final T NO_OUTPUT; @@ -313,13 +350,13 @@ public class FSTCompiler { && (doShareNonSingletonNodes || nodeIn.numArcs <= 1) && tailLength <= shareMaxTailLength) { if (nodeIn.numArcs == 0) { - node = fst.addNode(this, nodeIn); + node = addNode(nodeIn); lastFrozenNode = node; } else { node = dedupHash.add(this, nodeIn); } } else { - node = fst.addNode(this, nodeIn); + node = addNode(nodeIn); } assert node != -2; @@ -337,6 +374,370 @@ public class FSTCompiler { return fn; } + // serializes new node by appending its bytes to the end + // of the current byte[] + long addNode(FSTCompiler.UnCompiledNode nodeIn) throws IOException { + T NO_OUTPUT = fst.outputs.getNoOutput(); + + // System.out.println("FST.addNode pos=" + bytes.getPosition() + " numArcs=" + nodeIn.numArcs); + if (nodeIn.numArcs == 0) { + if (nodeIn.isFinal) { + return FINAL_END_NODE; + } else { + return NON_FINAL_END_NODE; + } + } + final long startAddress = bytes.getPosition(); + // System.out.println(" startAddr=" + startAddress); + + final boolean doFixedLengthArcs = shouldExpandNodeWithFixedLengthArcs(nodeIn); + if (doFixedLengthArcs) { + // System.out.println(" fixed length arcs"); + if (numBytesPerArc.length < nodeIn.numArcs) { + numBytesPerArc = new int[ArrayUtil.oversize(nodeIn.numArcs, Integer.BYTES)]; + numLabelBytesPerArc = new 
int[numBytesPerArc.length]; + } + } + + arcCount += nodeIn.numArcs; + + final int lastArc = nodeIn.numArcs - 1; + + long lastArcStart = bytes.getPosition(); + int maxBytesPerArc = 0; + int maxBytesPerArcWithoutLabel = 0; + for (int arcIdx = 0; arcIdx < nodeIn.numArcs; arcIdx++) { + final FSTCompiler.Arc arc = nodeIn.arcs[arcIdx]; + final FSTCompiler.CompiledNode target = (FSTCompiler.CompiledNode) arc.target; + int flags = 0; + // System.out.println(" arc " + arcIdx + " label=" + arc.label + " -> target=" + + // target.node); + + if (arcIdx == lastArc) { + flags += BIT_LAST_ARC; + } + + if (lastFrozenNode == target.node && !doFixedLengthArcs) { + // TODO: for better perf (but more RAM used) we + // could avoid this except when arc is "near" the + // last arc: + flags += BIT_TARGET_NEXT; + } + + if (arc.isFinal) { + flags += BIT_FINAL_ARC; + if (arc.nextFinalOutput != NO_OUTPUT) { + flags += BIT_ARC_HAS_FINAL_OUTPUT; + } + } else { + assert arc.nextFinalOutput == NO_OUTPUT; + } + + boolean targetHasArcs = target.node > 0; + + if (!targetHasArcs) { + flags += BIT_STOP_NODE; + } + + if (arc.output != NO_OUTPUT) { + flags += BIT_ARC_HAS_OUTPUT; + } + + bytes.writeByte((byte) flags); + long labelStart = bytes.getPosition(); + writeLabel(bytes, arc.label); + int numLabelBytes = (int) (bytes.getPosition() - labelStart); + + // System.out.println(" write arc: label=" + (char) arc.label + " flags=" + flags + " + // target=" + target.node + " pos=" + bytes.getPosition() + " output=" + + // outputs.outputToString(arc.output)); + + if (arc.output != NO_OUTPUT) { + fst.outputs.write(arc.output, bytes); + // System.out.println(" write output"); + } + + if (arc.nextFinalOutput != NO_OUTPUT) { + // System.out.println(" write final output"); + fst.outputs.writeFinalOutput(arc.nextFinalOutput, bytes); + } + + if (targetHasArcs && (flags & BIT_TARGET_NEXT) == 0) { + assert target.node > 0; + // System.out.println(" write target"); + bytes.writeVLong(target.node); + } + + // just 
write the arcs "like normal" on first pass, but record how many bytes each one took + // and max byte size: + if (doFixedLengthArcs) { + int numArcBytes = (int) (bytes.getPosition() - lastArcStart); + numBytesPerArc[arcIdx] = numArcBytes; + numLabelBytesPerArc[arcIdx] = numLabelBytes; + lastArcStart = bytes.getPosition(); + maxBytesPerArc = Math.max(maxBytesPerArc, numArcBytes); + maxBytesPerArcWithoutLabel = + Math.max(maxBytesPerArcWithoutLabel, numArcBytes - numLabelBytes); + // System.out.println(" arcBytes=" + numArcBytes + " labelBytes=" + numLabelBytes); + } + } + + // TODO: try to avoid wasteful cases: disable doFixedLengthArcs in that case + /* + * + * LUCENE-4682: what is a fair heuristic here? + * It could involve some of these: + * 1. how "busy" the node is: nodeIn.inputCount relative to frontier[0].inputCount? + * 2. how much binSearch saves over scan: nodeIn.numArcs + * 3. waste: numBytes vs numBytesExpanded + * + * the one below just looks at #3 + if (doFixedLengthArcs) { + // rough heuristic: make this 1.25 "waste factor" a parameter to the phd ctor???? 
+ int numBytes = lastArcStart - startAddress; + int numBytesExpanded = maxBytesPerArc * nodeIn.numArcs; + if (numBytesExpanded > numBytes*1.25) { + doFixedLengthArcs = false; + } + } + */ + + if (doFixedLengthArcs) { + assert maxBytesPerArc > 0; + // 2nd pass just "expands" all arcs to take up a fixed byte size + + int labelRange = nodeIn.arcs[nodeIn.numArcs - 1].label - nodeIn.arcs[0].label + 1; + assert labelRange > 0; + if (shouldExpandNodeWithDirectAddressing( + nodeIn, maxBytesPerArc, maxBytesPerArcWithoutLabel, labelRange)) { + writeNodeForDirectAddressing(nodeIn, startAddress, maxBytesPerArcWithoutLabel, labelRange); + directAddressingNodeCount++; + } else { + writeNodeForBinarySearch(nodeIn, startAddress, maxBytesPerArc); + binarySearchNodeCount++; + } + } + + final long thisNodeAddress = bytes.getPosition() - 1; + bytes.reverse(startAddress, thisNodeAddress); + nodeCount++; + return thisNodeAddress; + } + + private void writeLabel(DataOutput out, int v) throws IOException { + assert v >= 0 : "v=" + v; + if (fst.inputType == INPUT_TYPE.BYTE1) { + assert v <= 255 : "v=" + v; + out.writeByte((byte) v); + } else if (fst.inputType == INPUT_TYPE.BYTE2) { + assert v <= 65535 : "v=" + v; + out.writeShort((short) v); + } else { + out.writeVInt(v); + } + } + + /** + * Returns whether the given node should be expanded with fixed length arcs. Nodes will be + * expanded depending on their depth (distance from the root node) and their number of arcs. + * + *
<p>
Nodes with fixed length arcs use more space, because they encode all arcs with a fixed + * number of bytes, but they allow either binary search or direct addressing on the arcs (instead + * of linear scan) on lookup by arc label. + */ + private boolean shouldExpandNodeWithFixedLengthArcs(FSTCompiler.UnCompiledNode node) { + return allowFixedLengthArcs + && ((node.depth <= FIXED_LENGTH_ARC_SHALLOW_DEPTH + && node.numArcs >= FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS) + || node.numArcs >= FIXED_LENGTH_ARC_DEEP_NUM_ARCS); + } + + /** + * Returns whether the given node should be expanded with direct addressing instead of binary + * search. + * + *
<p>
Prefer direct addressing for performance if it does not oversize binary search byte size too + * much, so that the arcs can be directly addressed by label. + * + * @see FSTCompiler#getDirectAddressingMaxOversizingFactor() + */ + private boolean shouldExpandNodeWithDirectAddressing( + FSTCompiler.UnCompiledNode nodeIn, + int numBytesPerArc, + int maxBytesPerArcWithoutLabel, + int labelRange) { + // Anticipate precisely the size of the encodings. + int sizeForBinarySearch = numBytesPerArc * nodeIn.numArcs; + int sizeForDirectAddressing = + getNumPresenceBytes(labelRange) + + numLabelBytesPerArc[0] + + maxBytesPerArcWithoutLabel * nodeIn.numArcs; + + // Determine the allowed oversize compared to binary search. + // This is defined by a parameter of FST Builder (default 1: no oversize). + int allowedOversize = (int) (sizeForBinarySearch * getDirectAddressingMaxOversizingFactor()); + int expansionCost = sizeForDirectAddressing - allowedOversize; + + // Select direct addressing if either: + // - Direct addressing size is smaller than binary search. + // In this case, increment the credit by the reduced size (to use it later). + // - Direct addressing size is larger than binary search, but the positive credit allows the + // oversizing. + // In this case, decrement the credit by the oversize. + // In addition, do not try to oversize to a clearly too large node size + // (this is the DIRECT_ADDRESSING_MAX_OVERSIZE_WITH_CREDIT_FACTOR parameter). + if (expansionCost <= 0 + || (directAddressingExpansionCredit >= expansionCost + && sizeForDirectAddressing + <= allowedOversize * DIRECT_ADDRESSING_MAX_OVERSIZE_WITH_CREDIT_FACTOR)) { + directAddressingExpansionCredit -= expansionCost; + return true; + } + return false; + } + + private void writeNodeForBinarySearch( + FSTCompiler.UnCompiledNode nodeIn, long startAddress, int maxBytesPerArc) { + // Build the header in a buffer. 
+ // It is a false/special arc which is in fact a node header with node flags followed by node + // metadata. + fixedLengthArcsBuffer + .resetPosition() + .writeByte(ARCS_FOR_BINARY_SEARCH) + .writeVInt(nodeIn.numArcs) + .writeVInt(maxBytesPerArc); + int headerLen = fixedLengthArcsBuffer.getPosition(); + + // Expand the arcs in place, backwards. + long srcPos = bytes.getPosition(); + long destPos = startAddress + headerLen + nodeIn.numArcs * (long) maxBytesPerArc; + assert destPos >= srcPos; + if (destPos > srcPos) { + bytes.skipBytes((int) (destPos - srcPos)); + for (int arcIdx = nodeIn.numArcs - 1; arcIdx >= 0; arcIdx--) { + destPos -= maxBytesPerArc; + int arcLen = numBytesPerArc[arcIdx]; + srcPos -= arcLen; + if (srcPos != destPos) { + assert destPos > srcPos + : "destPos=" + + destPos + + " srcPos=" + + srcPos + + " arcIdx=" + + arcIdx + + " maxBytesPerArc=" + + maxBytesPerArc + + " arcLen=" + + arcLen + + " nodeIn.numArcs=" + + nodeIn.numArcs; + bytes.copyBytes(srcPos, destPos, arcLen); + } + } + } + + // Write the header. + bytes.writeBytes(startAddress, fixedLengthArcsBuffer.getBytes(), 0, headerLen); + } + + private void writeNodeForDirectAddressing( + FSTCompiler.UnCompiledNode nodeIn, + long startAddress, + int maxBytesPerArcWithoutLabel, + int labelRange) { + // Expand the arcs backwards in a buffer because we remove the labels. + // So the obtained arcs might occupy less space. This is the reason why this + // whole method is more complex. + // Drop the label bytes since we can infer the label based on the arc index, + // the presence bits, and the first label. Keep the first label. 
+ int headerMaxLen = 11; + int numPresenceBytes = getNumPresenceBytes(labelRange); + long srcPos = bytes.getPosition(); + int totalArcBytes = numLabelBytesPerArc[0] + nodeIn.numArcs * maxBytesPerArcWithoutLabel; + int bufferOffset = headerMaxLen + numPresenceBytes + totalArcBytes; + byte[] buffer = fixedLengthArcsBuffer.ensureCapacity(bufferOffset).getBytes(); + // Copy the arcs to the buffer, dropping all labels except first one. + for (int arcIdx = nodeIn.numArcs - 1; arcIdx >= 0; arcIdx--) { + bufferOffset -= maxBytesPerArcWithoutLabel; + int srcArcLen = numBytesPerArc[arcIdx]; + srcPos -= srcArcLen; + int labelLen = numLabelBytesPerArc[arcIdx]; + // Copy the flags. + bytes.copyBytes(srcPos, buffer, bufferOffset, 1); + // Skip the label, copy the remaining. + int remainingArcLen = srcArcLen - 1 - labelLen; + if (remainingArcLen != 0) { + bytes.copyBytes(srcPos + 1 + labelLen, buffer, bufferOffset + 1, remainingArcLen); + } + if (arcIdx == 0) { + // Copy the label of the first arc only. + bufferOffset -= labelLen; + bytes.copyBytes(srcPos + 1, buffer, bufferOffset, labelLen); + } + } + assert bufferOffset == headerMaxLen + numPresenceBytes; + + // Build the header in the buffer. + // It is a false/special arc which is in fact a node header with node flags followed by node + // metadata. + fixedLengthArcsBuffer + .resetPosition() + .writeByte(ARCS_FOR_DIRECT_ADDRESSING) + .writeVInt(labelRange) // labelRange instead of numArcs. + .writeVInt( + maxBytesPerArcWithoutLabel); // maxBytesPerArcWithoutLabel instead of maxBytesPerArc. + int headerLen = fixedLengthArcsBuffer.getPosition(); + + // Prepare the builder byte store. Enlarge or truncate if needed. 
+ long nodeEnd = startAddress + headerLen + numPresenceBytes + totalArcBytes; + long currentPosition = bytes.getPosition(); + if (nodeEnd >= currentPosition) { + bytes.skipBytes((int) (nodeEnd - currentPosition)); + } else { + bytes.truncate(nodeEnd); + } + assert bytes.getPosition() == nodeEnd; + + // Write the header. + long writeOffset = startAddress; + bytes.writeBytes(writeOffset, fixedLengthArcsBuffer.getBytes(), 0, headerLen); + writeOffset += headerLen; + + // Write the presence bits + writePresenceBits(nodeIn, writeOffset, numPresenceBytes); + writeOffset += numPresenceBytes; + + // Write the first label and the arcs. + bytes.writeBytes(writeOffset, fixedLengthArcsBuffer.getBytes(), bufferOffset, totalArcBytes); + } + + private void writePresenceBits( + FSTCompiler.UnCompiledNode nodeIn, long dest, int numPresenceBytes) { + long bytePos = dest; + byte presenceBits = 1; // The first arc is always present. + int presenceIndex = 0; + int previousLabel = nodeIn.arcs[0].label; + for (int arcIdx = 1; arcIdx < nodeIn.numArcs; arcIdx++) { + int label = nodeIn.arcs[arcIdx].label; + assert label > previousLabel; + presenceIndex += label - previousLabel; + while (presenceIndex >= Byte.SIZE) { + bytes.writeByte(bytePos++, presenceBits); + presenceBits = 0; + presenceIndex -= Byte.SIZE; + } + // Set the bit at presenceIndex to flag that the corresponding arc is present. + presenceBits |= 1 << presenceIndex; + previousLabel = label; + } + assert presenceIndex == (nodeIn.arcs[nodeIn.numArcs - 1].label - nodeIn.arcs[0].label) % 8; + assert presenceBits != 0; // The last byte is not 0. + assert (presenceBits & (1 << presenceIndex)) != 0; // The last arc is always present. 
+ bytes.writeByte(bytePos++, presenceBits); + assert bytePos - dest == numPresenceBytes; + } + private void freezeTail(int prefixLenPlus1) throws IOException { // System.out.println(" compileTail " + prefixLenPlus1); final int downTo = Math.max(1, prefixLenPlus1); diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java b/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java index 92a878b1fbb..144e4c5564a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java @@ -139,7 +139,7 @@ final class NodeHash { final long v = table.get(pos); if (v == 0) { // freeze & add - final long node = fst.addNode(fstCompiler, nodeIn); + final long node = fstCompiler.addNode(nodeIn); // System.out.println(" now freeze node=" + node); assert hash(node) == h : "frozenHash=" + hash(node) + " vs h=" + h; count++; diff --git a/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTs.java b/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTs.java index fcdacbd5b18..414f61d9dff 100644 --- a/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTs.java +++ b/lucene/core/src/test/org/apache/lucene/util/fst/TestFSTs.java @@ -1127,9 +1127,9 @@ public class TestFSTs extends LuceneTestCase { int children = verifyStateAndBelow(fst, new FST.Arc<>().copyFrom(arc), depth + 1); assertEquals( - (depth <= FST.FIXED_LENGTH_ARC_SHALLOW_DEPTH - && children >= FST.FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS) - || children >= FST.FIXED_LENGTH_ARC_DEEP_NUM_ARCS, + (depth <= FSTCompiler.FIXED_LENGTH_ARC_SHALLOW_DEPTH + && children >= FSTCompiler.FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS) + || children >= FSTCompiler.FIXED_LENGTH_ARC_DEEP_NUM_ARCS, expanded); if (arc.isLast()) break; } @@ -1141,8 +1141,9 @@ public class TestFSTs extends LuceneTestCase { } // Sanity check. 
- assertTrue(FST.FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS < FST.FIXED_LENGTH_ARC_DEEP_NUM_ARCS); - assertTrue(FST.FIXED_LENGTH_ARC_SHALLOW_DEPTH >= 0); + assertTrue( + FSTCompiler.FIXED_LENGTH_ARC_SHALLOW_NUM_ARCS < FSTCompiler.FIXED_LENGTH_ARC_DEEP_NUM_ARCS); + assertTrue(FSTCompiler.FIXED_LENGTH_ARC_SHALLOW_DEPTH >= 0); SyntheticData s = new SyntheticData(); @@ -1210,7 +1211,7 @@ public class TestFSTs extends LuceneTestCase { node.isFinal = true; rootNode.addArc('a', node); final FSTCompiler.CompiledNode frozen = new FSTCompiler.CompiledNode(); - frozen.node = fst.addNode(fstCompiler, node); + frozen.node = fstCompiler.addNode(node); rootNode.arcs[0].nextFinalOutput = 17L; rootNode.arcs[0].isFinal = true; rootNode.arcs[0].output = nothing; @@ -1223,13 +1224,13 @@ public class TestFSTs extends LuceneTestCase { new FSTCompiler.UnCompiledNode<>(fstCompiler, 0); rootNode.addArc('b', node); final FSTCompiler.CompiledNode frozen = new FSTCompiler.CompiledNode(); - frozen.node = fst.addNode(fstCompiler, node); + frozen.node = fstCompiler.addNode(node); rootNode.arcs[1].nextFinalOutput = nothing; rootNode.arcs[1].output = 42L; rootNode.arcs[1].target = frozen; } - fst.finish(fst.addNode(fstCompiler, rootNode)); + fst.finish(fstCompiler.addNode(rootNode)); StringWriter w = new StringWriter(); // Writer w = new OutputStreamWriter(new FileOutputStream("/x/tmp3/out.dot"));