diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java b/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java index 961511616ab..13ca4f6a14e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java @@ -16,8 +16,6 @@ */ package org.apache.lucene.util.fst; -import static org.apache.lucene.store.ByteBuffersDataOutput.ALLOCATE_BB_ON_HEAP; -import static org.apache.lucene.store.ByteBuffersDataOutput.NO_REUSE; import static org.apache.lucene.util.fst.FST.ARCS_FOR_BINARY_SEARCH; import static org.apache.lucene.util.fst.FST.ARCS_FOR_CONTINUOUS; import static org.apache.lucene.util.fst.FST.ARCS_FOR_DIRECT_ADDRESSING; @@ -34,7 +32,6 @@ import static org.apache.lucene.util.fst.FST.getNumPresenceBytes; import java.io.IOException; import java.util.Objects; import org.apache.lucene.store.ByteArrayDataOutput; -import org.apache.lucene.store.ByteBuffersDataOutput; import org.apache.lucene.store.DataOutput; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.ArrayUtil; @@ -153,8 +150,7 @@ public class FSTCompiler { * @return the DataOutput */ public static DataOutput getOnHeapReaderWriter(int blockBits) { - return new ReadWriteDataOutput( - new ByteBuffersDataOutput(blockBits, blockBits, ALLOCATE_BB_ON_HEAP, NO_REUSE)); + return new ReadWriteDataOutput(blockBits); } private FSTCompiler( diff --git a/lucene/core/src/java/org/apache/lucene/util/fst/ReadWriteDataOutput.java b/lucene/core/src/java/org/apache/lucene/util/fst/ReadWriteDataOutput.java index f5792779d00..a43c2f4f04d 100644 --- a/lucene/core/src/java/org/apache/lucene/util/fst/ReadWriteDataOutput.java +++ b/lucene/core/src/java/org/apache/lucene/util/fst/ReadWriteDataOutput.java @@ -16,8 +16,12 @@ */ package org.apache.lucene.util.fst; +import static org.apache.lucene.store.ByteBuffersDataOutput.ALLOCATE_BB_ON_HEAP; +import static 
org.apache.lucene.store.ByteBuffersDataOutput.NO_REUSE; + import java.io.IOException; -import org.apache.lucene.store.ByteBuffersDataInput; +import java.nio.ByteBuffer; +import java.util.List; import org.apache.lucene.store.ByteBuffersDataOutput; import org.apache.lucene.store.DataOutput; @@ -28,13 +32,19 @@ import org.apache.lucene.store.DataOutput; final class ReadWriteDataOutput extends DataOutput implements FSTReader { private final ByteBuffersDataOutput dataOutput; - // the DataInput to read from once we finish writing - private ByteBuffersDataInput dataInput; + private final int blockBits; + private final int blockSize; + private final int blockMask; + private List<ByteBuffer> byteBuffers; // whether this DataOutput is already frozen private boolean frozen; - public ReadWriteDataOutput(ByteBuffersDataOutput dataOutput) { - this.dataOutput = dataOutput; + public ReadWriteDataOutput(int blockBits) { + this.dataOutput = + new ByteBuffersDataOutput(blockBits, blockBits, ALLOCATE_BB_ON_HEAP, NO_REUSE); + this.blockBits = blockBits; + this.blockSize = 1 << blockBits; + this.blockMask = blockSize - 1; } @Override @@ -56,14 +66,62 @@ final class ReadWriteDataOutput extends DataOutput implements FSTReader { public void freeze() { frozen = true; - // this operation are costly, so we want to compute it once and cache - dataInput = dataOutput.toDataInput(); + // this operation is costly, so we want to compute it once and cache + this.byteBuffers = dataOutput.toWriteableBufferList(); + // readers read each ByteBuffer's backing array directly, so it must be accessible. The call to + // toWriteableBufferList() above is expected to return array-backed buffers. 
+ assert byteBuffers.stream().allMatch(ByteBuffer::hasArray); } @Override public FST.BytesReader getReverseBytesReader() { - assert dataInput != null; // freeze() must be called first - return new ReverseRandomAccessReader(dataInput); + assert byteBuffers != null; // freeze() must be called first + if (byteBuffers.size() == 1) { + // use a faster implementation for single-block case + return new ReverseBytesReader(byteBuffers.get(0).array()); + } + return new FST.BytesReader() { + private byte[] current = byteBuffers.get(0).array(); + private int nextBuffer = -1; + private int nextRead; + + @Override + public byte readByte() { + if (nextRead == -1) { + current = byteBuffers.get(nextBuffer--).array(); + nextRead = blockSize - 1; + } + return current[nextRead--]; + } + + @Override + public void skipBytes(long count) { + setPosition(getPosition() - count); + } + + @Override + public void readBytes(byte[] b, int offset, int len) { + for (int i = 0; i < len; i++) { + b[offset + i] = readByte(); + } + } + + @Override + public long getPosition() { + return ((long) nextBuffer + 1) * blockSize + nextRead; + } + + @Override + public void setPosition(long pos) { + int bufferIndex = (int) (pos >> blockBits); + if (nextBuffer != bufferIndex - 1) { + nextBuffer = bufferIndex - 1; + current = byteBuffers.get(bufferIndex).array(); + } + nextRead = (int) (pos & blockMask); + assert getPosition() == pos : "pos=" + pos + " getPos()=" + getPosition(); + } + }; } @Override