Use value-based LRU cache in NodeHash (#12738)

* Use value-based LRU cache in NodeHash (#12714)

* tidy code

* Add a nocommit about OffsetAndLength

* Fix the readBytes method

* Use List<byte[]> instead of ByteBlockPool

* Move nodesEqual to PagedGrowableHash

* Add generic type

* Fix the count variable

* Fix the RAM usage measurement

* Use PagedGrowableWriter instead of HashMap

* Remove unused generic type

* Update the ramBytesUsed formula

* Retain the FSTCompiler.addNode signature

* Switch back to ByteBlockPool

* Remove the unnecessary assertion

* Remove fstHashAddress

* Add some javadoc

* Fix the address offset when reading from fallback table

* tidy code

* Address comments

* Add assertions
Author: Dzung Bui
Date: 2023-11-05 00:24:49 +09:00 (committed by GitHub)
parent 5ef651fc4c
commit b8a9b0ae29
6 changed files with 368 additions and 142 deletions
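
At a high level: NodeHash previously stored only the addresses of frozen nodes and re-read their bytes from the FST's BytesStore to hash and compare, which forced the FST bytes to stay addressable in RAM. After this change the hash also keeps its own copy of each frozen node's bytes (the "value"), so lookups never touch the FST and its bytes can stream to disk as append-only writes. The LRU approximation itself is unchanged: a double-table scheme where a hit in the read-only fallback table promotes the entry back into the primary table. A minimal, self-contained sketch of that scheme (illustrative only, not the Lucene implementation):

import java.util.HashMap;
import java.util.Map;

final class DoubleTableLRU<K, V> {
  private Map<K, V> primary = new HashMap<>();
  private Map<K, V> fallback = new HashMap<>(); // read-only once swapped
  private final int maxPrimarySize;

  DoubleTableLRU(int maxPrimarySize) {
    this.maxPrimarySize = maxPrimarySize;
  }

  V get(K key) {
    V v = primary.get(key);
    if (v == null) {
      v = fallback.get(key);
      if (v != null) {
        primary.put(key, v); // promote on hit, like NodeHash promoting via setNode(getBytes(...))
      }
    }
    return v;
  }

  void put(K key, V value) {
    primary.put(key, value);
    if (primary.size() >= maxPrimarySize) {
      fallback = primary; // old primary becomes the read-only fallback
      primary = new HashMap<>(); // the previous fallback (the LRU half) is discarded
    }
  }
}

When the primary table fills its half of the budget, everything still sitting only in the old fallback (roughly the least recently used suffixes) is dropped, which is why the resulting FST is no longer strictly minimal.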

lucene/core/src/java/org/apache/lucene/util/ByteBlockPool.java

@@ -38,6 +38,8 @@ public final class ByteBlockPool implements Accountable {

   /** Abstract class for allocating and freeing byte blocks. */
   public abstract static class Allocator {
+    // TODO: ByteBlockPool assumes the blockSize is always {@link BYTE_BLOCK_SIZE}, but this class
+    // allows arbitrary values of blockSize. We should make them consistent.
     protected final int blockSize;

     protected Allocator(int blockSize) {
@@ -215,19 +217,38 @@ public final class ByteBlockPool implements Accountable {

   /** Appends the bytes in the provided {@link BytesRef} at the current position. */
   public void append(final BytesRef bytes) {
-    int bytesLeft = bytes.length;
-    int offset = bytes.offset;
+    append(bytes.bytes, bytes.offset, bytes.length);
+  }
+
+  /**
+   * Append the provided byte array at the current position.
+   *
+   * @param bytes the byte array to write
+   */
+  public void append(final byte[] bytes) {
+    append(bytes, 0, bytes.length);
+  }
+
+  /**
+   * Append some portion of the provided byte array at the current position.
+   *
+   * @param bytes the byte array to write
+   * @param offset the offset of the byte array
+   * @param length the number of bytes to write
+   */
+  public void append(final byte[] bytes, int offset, int length) {
+    int bytesLeft = length;
     while (bytesLeft > 0) {
       int bufferLeft = BYTE_BLOCK_SIZE - byteUpto;
       if (bytesLeft < bufferLeft) {
         // fits within current buffer
-        System.arraycopy(bytes.bytes, offset, buffer, byteUpto, bytesLeft);
+        System.arraycopy(bytes, offset, buffer, byteUpto, bytesLeft);
         byteUpto += bytesLeft;
         break;
       } else {
         // fill up this buffer and move to next one
         if (bufferLeft > 0) {
-          System.arraycopy(bytes.bytes, offset, buffer, byteUpto, bufferLeft);
+          System.arraycopy(bytes, offset, buffer, byteUpto, bufferLeft);
         }
         nextBuffer();
         bytesLeft -= bufferLeft;
@@ -256,6 +277,18 @@ public final class ByteBlockPool implements Accountable {
     }
   }

+  /**
+   * Read a single byte at the given offset.
+   *
+   * @param offset the offset to read
+   * @return the byte
+   */
+  public byte readByte(final long offset) {
+    int bufferIndex = (int) (offset >> BYTE_BLOCK_SHIFT);
+    int pos = (int) (offset & BYTE_BLOCK_MASK);
+    return buffers[bufferIndex][pos];
+  }
+
   @Override
   public long ramBytesUsed() {
     long size = BASE_RAM_BYTES;
@@ -269,4 +302,9 @@ public final class ByteBlockPool implements Accountable {
     }
     return size;
   }
+
+  /** the current position (absolute offset) of this byte pool */
+  public long getPosition() {
+    return bufferUpto * allocator.blockSize + byteUpto;
+  }
 }
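
The three additions above give ByteBlockPool a small random-access, append-oriented API. A quick sketch of how they compose (the pool and allocator are real classes; the demo class and values are illustrative):

import org.apache.lucene.util.ByteBlockPool;

public class ByteBlockPoolDemo {
  public static void main(String[] args) {
    ByteBlockPool pool = new ByteBlockPool(new ByteBlockPool.DirectAllocator());
    pool.nextBuffer(); // prime the pool with its first buffer

    byte[] node = {1, 2, 3, 4, 5};
    pool.append(node); // whole array, may span block boundaries
    pool.append(node, 1, 3); // bytes 2, 3, 4 only

    long end = pool.getPosition(); // 8 bytes written so far
    byte last = pool.readByte(end - 1); // random access: reads the value 4
    System.out.println(end + " " + last); // prints "8 4"
  }
}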

lucene/core/src/java/org/apache/lucene/util/fst/ByteBlockPoolReverseBytesReader.java (new file)

@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.fst;
+
+import java.io.IOException;
+import org.apache.lucene.util.ByteBlockPool;
+
+/** Reads in reverse from a ByteBlockPool. */
+final class ByteBlockPoolReverseBytesReader extends FST.BytesReader {
+
+  private final ByteBlockPool buf;
+  // the difference between the FST node address and the hash table copied node address
+  private long posDelta;
+  private long pos;
+
+  public ByteBlockPoolReverseBytesReader(ByteBlockPool buf) {
+    this.buf = buf;
+  }
+
+  @Override
+  public byte readByte() {
+    return buf.readByte(pos--);
+  }
+
+  @Override
+  public void readBytes(byte[] b, int offset, int len) {
+    for (int i = 0; i < len; i++) {
+      b[offset + i] = buf.readByte(pos--);
+    }
+  }
+
+  @Override
+  public void skipBytes(long numBytes) throws IOException {
+    pos -= numBytes;
+  }
+
+  @Override
+  public long getPosition() {
+    return pos + posDelta;
+  }
+
+  @Override
+  public void setPosition(long pos) {
+    this.pos = pos - posDelta;
+  }
+
+  @Override
+  public boolean reversed() {
+    return true;
+  }
+
+  public void setPosDelta(long posDelta) {
+    this.posDelta = posDelta;
+  }
+}
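
Why posDelta: hash lookups hand this reader an FST node address, but the node's bytes now live at a different offset inside the hash table's private pool, so the reader translates between the two address spaces. A sketch with made-up numbers (the class is package-private, so this would live in org.apache.lucene.util.fst):

package org.apache.lucene.util.fst;

import org.apache.lucene.util.ByteBlockPool;

class PosDeltaDemo {
  static void demo() {
    // hypothetical: a node frozen ending at FST address 1000 whose copied
    // bytes end at offset 40 in the hash table's pool
    ByteBlockPool copiedNodes = new ByteBlockPool(new ByteBlockPool.DirectAllocator());
    ByteBlockPoolReverseBytesReader in = new ByteBlockPoolReverseBytesReader(copiedNodes);
    in.setPosDelta(1000 - 40); // external (FST) address minus local (pool) address
    in.setPosition(1000); // callers keep thinking in FST addresses
    assert in.getPosition() == 1000; // but the next readByte() would fetch pool offset 40
  }
}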

lucene/core/src/java/org/apache/lucene/util/fst/BytesStore.java

@@ -444,11 +444,7 @@ class BytesStore extends DataOutput implements FSTReader {

   @Override
   public FST.BytesReader getReverseBytesReader() {
-    return getReverseReader(true);
-  }
-
-  FST.BytesReader getReverseReader(boolean allowSingle) {
-    if (allowSingle && blocks.size() == 1) {
+    if (blocks.size() == 1) {
       return new ReverseBytesReader(blocks.get(0));
     }
     return new FST.BytesReader() {

lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java

@@ -145,7 +145,7 @@ public class FSTCompiler<T> {
     if (suffixRAMLimitMB < 0) {
       throw new IllegalArgumentException("ramLimitMB must be >= 0; got: " + suffixRAMLimitMB);
     } else if (suffixRAMLimitMB > 0) {
-      dedupHash = new NodeHash<>(this, suffixRAMLimitMB, bytes.getReverseReader(false));
+      dedupHash = new NodeHash<>(this, suffixRAMLimitMB);
     } else {
       dedupHash = null;
     }
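
For reference, the RAM budget that decides whether a dedupHash exists at all is what callers set through the builder. A sketch of typical usage, assuming the Lucene 9.x FSTCompiler.Builder API (inputs must be added in sorted order):

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FSTCompiler;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;

public class SuffixLimitDemo {
  public static void main(String[] args) throws Exception {
    FSTCompiler<Long> compiler =
        new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, PositiveIntOutputs.getSingleton())
            .suffixRAMLimitMB(1.0) // > 0 enables the NodeHash suffix cache; 0 disables it
            .build();
    IntsRefBuilder scratch = new IntsRefBuilder();
    compiler.add(Util.toIntsRef(new BytesRef("cat"), scratch), 5L);
    compiler.add(Util.toIntsRef(new BytesRef("dog"), scratch), 7L);
    FST<Long> fst = compiler.compile();
    System.out.println(Util.get(fst, new BytesRef("dog"))); // prints 7
  }
}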

lucene/core/src/java/org/apache/lucene/util/fst/NodeHash.java

@@ -17,6 +17,7 @@
 package org.apache.lucene.util.fst;

 import java.io.IOException;
+import org.apache.lucene.util.ByteBlockPool;
 import org.apache.lucene.util.packed.PackedInts;
 import org.apache.lucene.util.packed.PagedGrowableWriter;
@@ -49,14 +50,17 @@ final class NodeHash<T> {

   private final FSTCompiler<T> fstCompiler;
   private final FST.Arc<T> scratchArc = new FST.Arc<>();
-  private final FST.BytesReader in;
+  // stores the node length of the last fallback-table hit in getFallback()
+  private int lastFallbackNodeLength;
+  // stores the hash slot of the last fallback-table hit in getFallback()
+  private long lastFallbackHashSlot;

   /**
    * ramLimitMB is the max RAM we can use for recording suffixes. If we hit this limit, the least
    * recently used suffixes are discarded, and the FST is no longer minimal. Still, larger
    * ramLimitMB will make the FST smaller (closer to minimal).
    */
-  public NodeHash(FSTCompiler<T> fstCompiler, double ramLimitMB, FST.BytesReader in) {
+  public NodeHash(FSTCompiler<T> fstCompiler, double ramLimitMB) {
     if (ramLimitMB <= 0) {
       throw new IllegalArgumentException("ramLimitMB must be > 0; got: " + ramLimitMB);
     }
@@ -70,28 +74,35 @@ final class NodeHash<T> {
     primaryTable = new PagedGrowableHash();
     this.fstCompiler = fstCompiler;
-    this.in = in;
   }

   private long getFallback(FSTCompiler.UnCompiledNode<T> nodeIn, long hash) throws IOException {
+    this.lastFallbackNodeLength = -1;
+    this.lastFallbackHashSlot = -1;
     if (fallbackTable == null) {
       // no fallback yet (primary table is not yet large enough to swap)
       return 0;
     }
-    long pos = hash & fallbackTable.mask;
+    long hashSlot = hash & fallbackTable.mask;
     int c = 0;
     while (true) {
-      long node = fallbackTable.get(pos);
-      if (node == 0) {
+      long nodeAddress = fallbackTable.getNodeAddress(hashSlot);
+      if (nodeAddress == 0) {
         // not found
         return 0;
-      } else if (nodesEqual(nodeIn, node)) {
-        // frozen version of this node is already here
-        return node;
+      } else {
+        int length = fallbackTable.nodesEqual(nodeIn, nodeAddress, hashSlot);
+        if (length != -1) {
+          // store the node length for further use
+          this.lastFallbackNodeLength = length;
+          this.lastFallbackHashSlot = hashSlot;
+          // frozen version of this node is already here
+          return nodeAddress;
+        }
       }

       // quadratic probe (but is it, really?)
-      pos = (pos + (++c)) & fallbackTable.mask;
+      hashSlot = (hashSlot + (++c)) & fallbackTable.mask;
     }
   }
@@ -99,36 +110,60 @@ final class NodeHash<T> {
     long hash = hash(nodeIn);

-    long pos = hash & primaryTable.mask;
+    long hashSlot = hash & primaryTable.mask;
     int c = 0;
     while (true) {
-      long node = primaryTable.get(pos);
-      if (node == 0) {
+      long nodeAddress = primaryTable.getNodeAddress(hashSlot);
+      if (nodeAddress == 0) {
         // node is not in primary table; is it in fallback table?
-        node = getFallback(nodeIn, hash);
-        if (node != 0) {
+        nodeAddress = getFallback(nodeIn, hash);
+        if (nodeAddress != 0) {
+          assert lastFallbackHashSlot != -1 && lastFallbackNodeLength != -1;
           // it was already in fallback -- promote to primary
-          primaryTable.set(pos, node);
+          // TODO: Copy directly between 2 ByteBlockPool to avoid double-copy
+          primaryTable.setNode(
+              hashSlot,
+              nodeAddress,
+              fallbackTable.getBytes(lastFallbackHashSlot, lastFallbackNodeLength));
         } else {
           // not in fallback either -- freeze & add the incoming node
+          long startAddress = fstCompiler.bytes.getPosition();
           // freeze & add
-          node = fstCompiler.addNode(nodeIn);
+          nodeAddress = fstCompiler.addNode(nodeIn);
+          // TODO: Write the bytes directly from BytesStore
           // we use 0 as empty marker in hash table, so it better be impossible to get a frozen node
           // at 0:
-          assert node != 0;
+          assert nodeAddress != FST.FINAL_END_NODE && nodeAddress != FST.NON_FINAL_END_NODE;
+          byte[] buf = new byte[Math.toIntExact(nodeAddress - startAddress + 1)];
+          fstCompiler.bytes.copyBytes(startAddress, buf, 0, buf.length);
+          primaryTable.setNode(hashSlot, nodeAddress, buf);

           // confirm frozen hash and unfrozen hash are the same
-          assert hash(node) == hash : "mismatch frozenHash=" + hash(node) + " vs hash=" + hash;
-          primaryTable.set(pos, node);
+          assert primaryTable.hash(nodeAddress, hashSlot) == hash
+              : "mismatch frozenHash="
+                  + primaryTable.hash(nodeAddress, hashSlot)
+                  + " vs hash="
+                  + hash;
         }

         // how many bytes would be used if we had "perfect" hashing:
-        long ramBytesUsed = primaryTable.count * PackedInts.bitsRequired(node) / 8;
+        //  - x2 for fstNodeAddress, for the FST node address
+        //  - x2 for copiedNodeAddress, for the copied node address
+        //  - the bytes copied out of the FST into the hashtable copiedNodes
+        // the x2 factors each account for approximate hash table overhead halfway between 33.3%
+        // and 66.6% occupancy
+        // note that some of the copiedNodes are shared between fallback and primary tables so this
+        // computation is pessimistic
+        long copiedBytes = primaryTable.copiedNodes.getPosition();
+        long ramBytesUsed =
+            primaryTable.count * 2 * PackedInts.bitsRequired(nodeAddress) / 8
+                + primaryTable.count * 2 * PackedInts.bitsRequired(copiedBytes) / 8
+                + copiedBytes;

         // NOTE: we could instead use the more precise RAM used, but this leads to unpredictable
         // quantized behavior due to 2X rehashing where for large ranges of the RAM limit, the
@@ -138,30 +173,29 @@ final class NodeHash<T> {
         // in smaller FSTs, even if the precise RAM used is not always under the limit.

         // divide limit by 2 because fallback gets half the RAM and primary gets the other half
-        // divide by 2 again to account for approximate hash table overhead halfway between 33.3%
-        // and 66.7% occupancy = 50%
-        if (ramBytesUsed >= ramLimitBytes / (2 * 2)) {
+        if (ramBytesUsed >= ramLimitBytes / 2) {
           // time to fallback -- fallback is now used read-only to promote a node (suffix) to
           // primary if we encounter it again
           fallbackTable = primaryTable;
           // size primary table the same size to reduce rehash cost
           // TODO: we could clear & reuse the previous fallbackTable, instead of allocating a new
           // one, to reduce GC load
-          primaryTable = new PagedGrowableHash(node, Math.max(16, primaryTable.entries.size()));
-        } else if (primaryTable.count > primaryTable.entries.size() * (2f / 3)) {
+          primaryTable =
+              new PagedGrowableHash(nodeAddress, Math.max(16, primaryTable.fstNodeAddress.size()));
+        } else if (primaryTable.count > primaryTable.fstNodeAddress.size() * (2f / 3)) {
           // rehash at 2/3 occupancy
-          primaryTable.rehash(node);
+          primaryTable.rehash(nodeAddress);
         }

-        return node;
+        return nodeAddress;

-      } else if (nodesEqual(nodeIn, node)) {
+      } else if (primaryTable.nodesEqual(nodeIn, nodeAddress, hashSlot) != -1) {
         // same node (in frozen form) is already in primary table
-        return node;
+        return nodeAddress;
       }

       // quadratic probe (but is it, really?)
-      pos = (pos + (++c)) & primaryTable.mask;
+      hashSlot = (hashSlot + (++c)) & primaryTable.mask;
     }
   }
@@ -186,13 +220,145 @@ final class NodeHash<T> {
     return h;
   }

+  /** Inner class because it needs access to hash function and FST bytes. */
+  private class PagedGrowableHash {
+    // storing the FST node address where the position is the masked hash of the node arcs
+    private PagedGrowableWriter fstNodeAddress;
+    // storing the local copiedNodes address in the same position as fstNodeAddress; here we are
+    // effectively storing a Map<Long, Long> from the FST node address to the copiedNodes address
+    private PagedGrowableWriter copiedNodeAddress;
+    private long count;
+    private long mask;
+    // storing the byte slice copied out of the FST for each node we added to the hash, so that we
+    // don't need to look up the FST itself and its bytes can stream directly to disk as
+    // append-only writes; each node is written contiguously after the previous one
+    private final ByteBlockPool copiedNodes;
+    // the {@link FST.BytesReader} to read from copiedNodes; we use this when computing a frozen
+    // node's hash or comparing a frozen with an unfrozen node
+    private final ByteBlockPoolReverseBytesReader bytesReader;
+
+    // 256K blocks, but note that the final block is sized only as needed so it won't use the full
+    // block size when just a few elements were written to it
+    private static final int BLOCK_SIZE_BYTES = 1 << 18;
+
+    public PagedGrowableHash() {
+      fstNodeAddress = new PagedGrowableWriter(16, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
+      copiedNodeAddress = new PagedGrowableWriter(16, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
+      mask = 15;
+      copiedNodes = new ByteBlockPool(new ByteBlockPool.DirectAllocator());
+      bytesReader = new ByteBlockPoolReverseBytesReader(copiedNodes);
+    }
+
+    public PagedGrowableHash(long lastNodeAddress, long size) {
+      fstNodeAddress =
+          new PagedGrowableWriter(
+              size, BLOCK_SIZE_BYTES, PackedInts.bitsRequired(lastNodeAddress), PackedInts.COMPACT);
+      copiedNodeAddress = new PagedGrowableWriter(size, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
+      mask = size - 1;
+      assert (mask & size) == 0 : "size must be a power-of-2; got size=" + size + " mask=" + mask;
+      copiedNodes = new ByteBlockPool(new ByteBlockPool.DirectAllocator());
+      bytesReader = new ByteBlockPoolReverseBytesReader(copiedNodes);
+    }
+
+    /**
+     * Get the copied bytes at the provided hash slot
+     *
+     * @param hashSlot the hash slot to read from
+     * @param length the number of bytes to read
+     * @return the copied byte array
+     */
+    public byte[] getBytes(long hashSlot, int length) {
+      long address = copiedNodeAddress.get(hashSlot);
+      assert address - length + 1 >= 0;
+      byte[] buf = new byte[length];
+      copiedNodes.readBytes(address - length + 1, buf, 0, length);
+      return buf;
+    }
+
+    /**
+     * Get the node address from the provided hash slot
+     *
+     * @param hashSlot the hash slot to read
+     * @return the node address
+     */
+    public long getNodeAddress(long hashSlot) {
+      return fstNodeAddress.get(hashSlot);
+    }
+
+    /**
+     * Set the node address and copied bytes at the provided hash slot
+     *
+     * @param hashSlot the hash slot to write to
+     * @param nodeAddress the node address
+     * @param bytes the node bytes to be copied
+     */
+    public void setNode(long hashSlot, long nodeAddress, byte[] bytes) {
+      assert fstNodeAddress.get(hashSlot) == 0;
+      fstNodeAddress.set(hashSlot, nodeAddress);
+      count++;
+      copiedNodes.append(bytes);
+      // write the offset, which points to the last byte of the node we copied, since we later
+      // read this node in reverse
+      assert copiedNodeAddress.get(hashSlot) == 0;
+      copiedNodeAddress.set(hashSlot, copiedNodes.getPosition() - 1);
+    }
+
+    private void rehash(long lastNodeAddress) throws IOException {
+      // TODO: https://github.com/apache/lucene/issues/12744
+      // should we always use a small startBitsPerValue here (e.g. 8) instead of basing it off of
+      // lastNodeAddress?
+
+      // double hash table size on each rehash
+      long newSize = 2 * fstNodeAddress.size();
+      PagedGrowableWriter newCopiedNodeAddress =
+          new PagedGrowableWriter(
+              newSize,
+              BLOCK_SIZE_BYTES,
+              PackedInts.bitsRequired(copiedNodes.getPosition()),
+              PackedInts.COMPACT);
+      PagedGrowableWriter newFSTNodeAddress =
+          new PagedGrowableWriter(
+              newSize,
+              BLOCK_SIZE_BYTES,
+              PackedInts.bitsRequired(lastNodeAddress),
+              PackedInts.COMPACT);
+      long newMask = newFSTNodeAddress.size() - 1;
+      for (long idx = 0; idx < fstNodeAddress.size(); idx++) {
+        long address = fstNodeAddress.get(idx);
+        if (address != 0) {
+          long hashSlot = hash(address, idx) & newMask;
+          int c = 0;
+          while (true) {
+            if (newFSTNodeAddress.get(hashSlot) == 0) {
+              newFSTNodeAddress.set(hashSlot, address);
+              newCopiedNodeAddress.set(hashSlot, copiedNodeAddress.get(idx));
+              break;
+            }
+            // quadratic probe
+            hashSlot = (hashSlot + (++c)) & newMask;
+          }
+        }
+      }
+      mask = newMask;
+      fstNodeAddress = newFSTNodeAddress;
+      copiedNodeAddress = newCopiedNodeAddress;
+    }
+
     // hash code for a frozen node. this must precisely match the hash computation of an unfrozen
     // node!
-  private long hash(long node) throws IOException {
+    private long hash(long nodeAddress, long hashSlot) throws IOException {
+      FST.BytesReader in = getBytesReader(nodeAddress, hashSlot);
       final int PRIME = 31;
       long h = 0;
-    fstCompiler.fst.readFirstRealTargetArc(node, scratchArc, in);
+      fstCompiler.fst.readFirstRealTargetArc(nodeAddress, scratchArc, in);
       while (true) {
         h = PRIME * h + scratchArc.label();
         h = PRIME * h + (int) (scratchArc.target() ^ (scratchArc.target() >> 32));
@@ -211,10 +377,15 @@ final class NodeHash<T> {
     }

     /**
-     * Compares an unfrozen node (UnCompiledNode) with a frozen node at byte location address (long),
-     * returning true if they are equal.
+     * Compares an unfrozen node (UnCompiledNode) with a frozen node at byte location address
+     * (long), returning the node length if the two nodes are equal, or -1 otherwise
+     *
+     * <p>The node length will be used to promote the node from the fallback table to the primary
+     * table
      */
-    private boolean nodesEqual(FSTCompiler.UnCompiledNode<T> node, long address) throws IOException {
+    private int nodesEqual(FSTCompiler.UnCompiledNode<T> node, long address, long hashSlot)
+        throws IOException {
+      FST.BytesReader in = getBytesReader(address, hashSlot);
       fstCompiler.fst.readFirstRealTargetArc(address, scratchArc, in);

       // fail fast for a node with fixed length arcs
@@ -226,7 +397,7 @@ final class NodeHash<T> {
         case FST.ARCS_FOR_BINARY_SEARCH:
           // sparse
           if (node.numArcs != scratchArc.numArcs()) {
-            return false;
+            return -1;
           }
           break;
         case FST.ARCS_FOR_DIRECT_ADDRESSING:
@@ -234,7 +405,7 @@ final class NodeHash<T> {
           // not actually be arcs), and the number of arcs
           if ((node.arcs[node.numArcs - 1].label - node.arcs[0].label + 1) != scratchArc.numArcs()
               || node.numArcs != FST.Arc.BitTable.countBits(scratchArc, in)) {
-            return false;
+            return -1;
           }
           break;
         default:
@@ -250,14 +421,15 @@ final class NodeHash<T> {
             || ((FSTCompiler.CompiledNode) arc.target).node != scratchArc.target()
             || arc.nextFinalOutput.equals(scratchArc.nextFinalOutput()) == false
             || arc.isFinal != scratchArc.isFinal()) {
-          return false;
+          return -1;
         }

         if (scratchArc.isLast()) {
           if (arcUpto == node.numArcs - 1) {
-            return true;
+            // the position is 1 index past the starting address, as we are reading in reverse
+            return Math.toIntExact(address - in.getPosition());
           } else {
-            return false;
+            return -1;
           }
         }
@@ -266,69 +438,15 @@ final class NodeHash<T> {
       // unfrozen node has fewer arcs than frozen node
-      return false;
+      return -1;
     }
-  }
-
-  /** Inner class because it needs access to hash function and FST bytes. */
-  private class PagedGrowableHash {
-    private PagedGrowableWriter entries;
-    private long count;
-    private long mask;
-
-    // 256K blocks, but note that the final block is sized only as needed so it won't use the full
-    // block size when just a few elements were written to it
-    private static final int BLOCK_SIZE_BYTES = 1 << 18;
-
-    public PagedGrowableHash() {
-      entries = new PagedGrowableWriter(16, BLOCK_SIZE_BYTES, 8, PackedInts.COMPACT);
-      mask = 15;
-    }
-
-    public PagedGrowableHash(long lastNodeAddress, long size) {
-      entries =
-          new PagedGrowableWriter(
-              size, BLOCK_SIZE_BYTES, PackedInts.bitsRequired(lastNodeAddress), PackedInts.COMPACT);
-      mask = size - 1;
-      assert (mask & size) == 0 : "size must be a power-of-2; got size=" + size + " mask=" + mask;
-    }
-
-    public long get(long index) {
-      return entries.get(index);
-    }
-
-    public void set(long index, long pointer) throws IOException {
-      entries.set(index, pointer);
-      count++;
-    }
-
-    private void rehash(long lastNodeAddress) throws IOException {
-      // double hash table size on each rehash
-      PagedGrowableWriter newEntries =
-          new PagedGrowableWriter(
-              2 * entries.size(),
-              BLOCK_SIZE_BYTES,
-              PackedInts.bitsRequired(lastNodeAddress),
-              PackedInts.COMPACT);
-      long newMask = newEntries.size() - 1;
-      for (long idx = 0; idx < entries.size(); idx++) {
-        long address = entries.get(idx);
-        if (address != 0) {
-          long pos = hash(address) & newMask;
-          int c = 0;
-          while (true) {
-            if (newEntries.get(pos) == 0) {
-              newEntries.set(pos, address);
-              break;
-            }
-            // quadratic probe
-            pos = (pos + (++c)) & newMask;
-          }
-        }
-      }
-      mask = newMask;
-      entries = newEntries;
+
+    private FST.BytesReader getBytesReader(long nodeAddress, long hashSlot) {
+      // make sure the nodeAddress and hashSlot are consistent
+      assert fstNodeAddress.get(hashSlot) == nodeAddress;
+      long localAddress = copiedNodeAddress.get(hashSlot);
+      bytesReader.setPosDelta(nodeAddress - localAddress);
+      return bytesReader;
     }
   }
 }
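
To make the new ramBytesUsed estimate concrete, a worked example with made-up numbers (one million cached nodes, last node frozen near FST address 64M, 48 MB of copied node bytes):

import org.apache.lucene.util.packed.PackedInts;

public class RamEstimateDemo {
  public static void main(String[] args) {
    long count = 1_000_000; // entries in the primary table
    long nodeAddress = 64_000_000; // PackedInts.bitsRequired(..) == 26
    long copiedBytes = 48_000_000; // also 26 bits

    // same shape as the formula in add(): two packed arrays (FST address and
    // copied-node address), each doubled for ~50% occupancy, plus the copied bytes
    long ramBytesUsed =
        count * 2 * PackedInts.bitsRequired(nodeAddress) / 8
            + count * 2 * PackedInts.bitsRequired(copiedBytes) / 8
            + copiedBytes;
    System.out.println(ramBytesUsed); // 61_000_000, i.e. ~61 MB, compared against ramLimitBytes / 2
  }
}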

lucene/core/src/test/org/apache/lucene/util/TestByteBlockPool.java

@@ -79,6 +79,7 @@ public class TestByteBlockPool extends LuceneTestCase {
     ByteBlockPool pool = new ByteBlockPool(new ByteBlockPool.DirectTrackingAllocator(bytesUsed));
     pool.nextBuffer();
+    long totalBytes = 0;
     List<byte[]> items = new ArrayList<>();
     for (int i = 0; i < 100; i++) {
       int size;
@@ -91,6 +92,10 @@ public class TestByteBlockPool extends LuceneTestCase {
       random().nextBytes(bytes);
       items.add(bytes);
       pool.append(new BytesRef(bytes));
+      totalBytes += size;
+
+      // make sure we report the correct position
+      assertEquals(totalBytes, pool.getPosition());
     }

     long position = 0;