From 3003731aa94225b3a07cb9fb667fc8fd62bb2d10 Mon Sep 17 00:00:00 2001
From: Adrien Grand <jpountz@gmail.com>
Date: Fri, 10 May 2024 13:42:35 +0200
Subject: [PATCH] Add IndexInput#prefetch. (#13337)

This adds `IndexInput#prefetch`, which is an optional operation that instructs
the `IndexInput` to start fetching bytes from storage in the background. These
bytes will be picked up by follow-up calls to the `IndexInput#readXXX` methods.
In the future, this will help Lucene move from a maximum of one I/O operation
per search thread to one I/O operation per search thread per `IndexInput`.
Typically, when running a query on two terms, the I/O into the terms dictionary
is sequential today. In the future, we would ideally do these I/Os in parallel
using this new API. Note that this will require API changes to some classes
including `TermsEnum`.

I settled on this API because it's simple and wouldn't require making all
Lucene APIs asynchronous to take advantage of extra I/O concurrency, which I
worry would make the query evaluation logic too complicated.

This change will require follow-ups to start using this new API when working
with terms dictionaries, postings, etc.

Relates #13179

Co-authored-by: Uwe Schindler <uschindler@apache.org>
---
 lucene/CHANGES.txt                            |  4 ++
 .../org/apache/lucene/store/IndexInput.java   | 11 ++++
 .../lucene/store/MemorySegmentIndexInput.java | 45 ++++++++++++++
 .../MemorySegmentIndexInputProvider.java      |  5 +-
 .../org/apache/lucene/store/NativeAccess.java |  8 +++
 .../lucene/store/PosixNativeAccess.java       | 44 ++++++++++---
 .../tests/store/BaseDirectoryTestCase.java    | 62 +++++++++++++++++++
 .../tests/store/MockIndexInputWrapper.java    |  6 ++
 8 files changed, 174 insertions(+), 11 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index a7c29cef3cd..22570d8c8e1 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -102,6 +102,10 @@ API Changes
   Additionally, deprecated methods have been removed from ByteBuffersIndexInput, BooleanQuery and others. Please refer
   to MIGRATE.md for further details. (Sanjay Dutt)
 
+* GITHUB#13337: Introduce new `IndexInput#prefetch(long)` API to give a hint to
+  the directory about bytes that are about to be read. (Adrien Grand, Uwe
+  Schindler)
+
 New Features
 ---------------------
 
diff --git a/lucene/core/src/java/org/apache/lucene/store/IndexInput.java b/lucene/core/src/java/org/apache/lucene/store/IndexInput.java
index 3f703bc54b2..ec7a1294d40 100644
--- a/lucene/core/src/java/org/apache/lucene/store/IndexInput.java
+++ b/lucene/core/src/java/org/apache/lucene/store/IndexInput.java
@@ -191,4 +191,15 @@ public abstract class IndexInput extends DataInput implements Closeable {
       };
     }
   }
+
+  /**
+   * Optional method: Give a hint to this input that some bytes will be read in the near future.
+   * IndexInput implementations may take advantage of this hint to start fetching pages of data
+   * immediately from storage.
+   *
+   * <p>The default implementation is a no-op.
+   *
+   * @param length the number of bytes to prefetch
+   */
+  public void prefetch(long length) throws IOException {}
 }
diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
index 7d1e2572fdb..a8b8d6da3cd 100644
--- a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
+++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
@@ -24,6 +24,7 @@ import java.lang.foreign.ValueLayout;
 import java.nio.ByteOrder;
 import java.util.Arrays;
 import java.util.Objects;
+import java.util.Optional;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.GroupVIntUtil;
 
@@ -44,6 +45,7 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
       ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
   static final ValueLayout.OfFloat LAYOUT_LE_FLOAT =
       ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+  private static final Optional<NativeAccess> NATIVE_ACCESS = NativeAccess.getImplementation();
 
   final long length;
   final long chunkSizeMask;
@@ -310,6 +312,49 @@ abstract class MemorySegmentIndexInput extends IndexInput implements RandomAcces
     }
   }
 
+  @Override
+  public void prefetch(long length) throws IOException {
+    ensureOpen();
+
+    Objects.checkFromIndexSize(getFilePointer(), length, length());
+
+    if (NATIVE_ACCESS.isEmpty()) {
+      return;
+    }
+    final NativeAccess nativeAccess = NATIVE_ACCESS.get();
+
+    // If at the boundary between two chunks, move to the next one.
+    seek(getFilePointer());
+    try {
+      // Compute the intersection of the current segment and the region that should be prefetched.
+      long offset = curPosition;
+      if (offset + length > curSegment.byteSize()) {
+        // Only prefetch bytes that are stored in the current segment. There may be bytes on the
+        // next segment but this case is rare enough that we don't try to optimize it and keep
+        // things simple instead.
+        length = curSegment.byteSize() - curPosition;
+      }
+      // Now align offset with the page size, this is required for madvise.
+      // Compute the offset of the current position in the OS's page.
+      final long offsetInPage = (curSegment.address() + offset) % nativeAccess.getPageSize();
+      offset -= offsetInPage;
+      length += offsetInPage;
+      if (offset < 0) {
+        // The start of the page is outside of this segment, ignore.
+        return;
+      }
+
+      final MemorySegment prefetchSlice = curSegment.asSlice(offset, length);
+      nativeAccess.madviseWillNeed(prefetchSlice);
+    } catch (
+        @SuppressWarnings("unused")
+        IndexOutOfBoundsException e) {
+      throw new EOFException("Read past EOF: " + this);
+    } catch (NullPointerException | IllegalStateException e) {
+      throw alreadyClosed(e);
+    }
+  }
+
   @Override
   public byte readByte(long pos) throws IOException {
     try {
diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java
index 887956f306c..e1655101d75 100644
--- a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java
+++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java
@@ -111,10 +111,11 @@ final class MemorySegmentIndexInputProvider implements MMapDirectory.MMapIndexIn
         throw convertMapFailedIOException(ioe, resourceDescription, segSize);
       }
       // if preload apply it without madvise.
-      // if chunk size is too small (2 MiB), disable madvise support (incorrect alignment)
+      // skip madvise if the address of our segment is not page-aligned (small segments due to
+      // internal FileChannel logic)
       if (preload) {
         segment.load();
-      } else if (nativeAccess.isPresent() && chunkSizePower >= 21) {
+      } else if (nativeAccess.filter(na -> segment.address() % na.getPageSize() == 0).isPresent()) {
         nativeAccess.get().madvise(segment, readAdvice);
       }
       segments[segNr] = segment;
diff --git a/lucene/core/src/java21/org/apache/lucene/store/NativeAccess.java b/lucene/core/src/java21/org/apache/lucene/store/NativeAccess.java
index f4bc4f89d58..affc0e2ac71 100644
--- a/lucene/core/src/java21/org/apache/lucene/store/NativeAccess.java
+++ b/lucene/core/src/java21/org/apache/lucene/store/NativeAccess.java
@@ -27,6 +27,14 @@ abstract class NativeAccess {
   /** Invoke the {@code madvise} call for the given {@link MemorySegment}. */
   public abstract void madvise(MemorySegment segment, ReadAdvice readAdvice) throws IOException;
 
+  /**
+   * Invoke the {@code madvise} call for the given {@link MemorySegment} with {@code MADV_WILLNEED}.
+   */
+  public abstract void madviseWillNeed(MemorySegment segment) throws IOException;
+
+  /** Returns native page size. */
+  public abstract int getPageSize();
+
   /**
    * Return the NativeAccess instance for this platform. At moment we only support Linux and MacOS
    */
diff --git a/lucene/core/src/java21/org/apache/lucene/store/PosixNativeAccess.java b/lucene/core/src/java21/org/apache/lucene/store/PosixNativeAccess.java
index b74bd7fe365..93caca788b1 100644
--- a/lucene/core/src/java21/org/apache/lucene/store/PosixNativeAccess.java
+++ b/lucene/core/src/java21/org/apache/lucene/store/PosixNativeAccess.java
@@ -50,6 +50,7 @@ final class PosixNativeAccess extends NativeAccess {
   public static final int POSIX_MADV_DONTNEED = 4;
 
   private static final MethodHandle MH$posix_madvise;
+  private static final int PAGE_SIZE;
 
   private static final Optional<NativeAccess> INSTANCE;
 
@@ -60,10 +61,14 @@ final class PosixNativeAccess extends NativeAccess {
   }
 
   static {
+    final Linker linker = Linker.nativeLinker();
+    final SymbolLookup stdlib = linker.defaultLookup();
     MethodHandle adviseHandle = null;
+    int pagesize = -1;
     PosixNativeAccess instance = null;
     try {
-      adviseHandle = lookupMadvise();
+      adviseHandle = lookupMadvise(linker, stdlib);
+      pagesize = (int) lookupGetPageSize(linker, stdlib).invokeExact();
       instance = new PosixNativeAccess();
     } catch (UnsupportedOperationException uoe) {
       LOG.warning(uoe.getMessage());
@@ -77,14 +82,17 @@ final class PosixNativeAccess extends NativeAccess {
                   + "pass the following on command line: --enable-native-access=%s",
               Optional.ofNullable(PosixNativeAccess.class.getModule().getName())
                   .orElse("ALL-UNNAMED")));
+    } catch (RuntimeException | Error e) {
+      throw e;
+    } catch (Throwable e) {
+      throw new AssertionError(e);
     }
     MH$posix_madvise = adviseHandle;
+    PAGE_SIZE = pagesize;
     INSTANCE = Optional.ofNullable(instance);
   }
 
-  private static MethodHandle lookupMadvise() {
-    final Linker linker = Linker.nativeLinker();
-    final SymbolLookup stdlib = linker.defaultLookup();
+  private static MethodHandle lookupMadvise(Linker linker, SymbolLookup stdlib) {
     return findFunction(
         linker,
         stdlib,
@@ -96,6 +104,10 @@ final class PosixNativeAccess extends NativeAccess {
             ValueLayout.JAVA_INT));
   }
 
+  private static MethodHandle lookupGetPageSize(Linker linker, SymbolLookup stdlib) {
+    return findFunction(linker, stdlib, "getpagesize", FunctionDescriptor.of(ValueLayout.JAVA_INT));
+  }
+
   private static MethodHandle findFunction(
       Linker linker, SymbolLookup lookup, String name, FunctionDescriptor desc) {
     final MemorySegment symbol =
@@ -110,17 +122,26 @@ final class PosixNativeAccess extends NativeAccess {
 
   @Override
   public void madvise(MemorySegment segment, ReadAdvice readAdvice) throws IOException {
-    // Note: madvise is bypassed if the segment should be preloaded via MemorySegment#load.
-    if (segment.byteSize() == 0L) {
-      return; // empty segments should be excluded, because they may have no address at all
-    }
     final Integer advice = mapReadAdvice(readAdvice);
     if (advice == null) {
       return; // do nothing
     }
+    madvise(segment, advice);
+  }
+
+  @Override
+  public void madviseWillNeed(MemorySegment segment) throws IOException {
+    madvise(segment, POSIX_MADV_WILLNEED);
+  }
+
+  private void madvise(MemorySegment segment, int advice) throws IOException {
+    // Note: madvise is bypassed if the segment should be preloaded via MemorySegment#load.
+    if (segment.byteSize() == 0L) {
+      return; // empty segments should be excluded, because they may have no address at all
+    }
     final int ret;
     try {
-      ret = (int) MH$posix_madvise.invokeExact(segment, segment.byteSize(), advice.intValue());
+      ret = (int) MH$posix_madvise.invokeExact(segment, segment.byteSize(), advice);
     } catch (Throwable th) {
       throw new AssertionError(th);
     }
@@ -143,4 +164,9 @@ final class PosixNativeAccess extends NativeAccess {
       case RANDOM_PRELOAD -> null;
     };
   }
+
+  @Override
+  public int getPageSize() {
+    return PAGE_SIZE;
+  }
 }
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java
index 24d8db0b02f..775c6d5d318 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java
@@ -58,6 +58,7 @@ import org.apache.lucene.tests.mockfile.ExtrasFS;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.packed.PackedInts;
 import org.junit.Assert;
@@ -1512,4 +1513,65 @@ public abstract class BaseDirectoryTestCase extends LuceneTestCase {
     dir.deleteFile("group-varint");
     dir.deleteFile("vint");
   }
+
+  public void testPrefetch() throws IOException {
+    doTestPrefetch(0);
+  }
+
+  public void testPrefetchOnSlice() throws IOException {
+    doTestPrefetch(TestUtil.nextInt(random(), 1, 1024));
+  }
+
+  private void doTestPrefetch(int startOffset) throws IOException {
+    try (Directory dir = getDirectory(createTempDir())) {
+      final int totalLength = startOffset + TestUtil.nextInt(random(), 16384, 65536);
+      byte[] arr = new byte[totalLength];
+      random().nextBytes(arr);
+      try (IndexOutput out = dir.createOutput("temp.bin", IOContext.DEFAULT)) {
+        out.writeBytes(arr, arr.length);
+      }
+      byte[] temp = new byte[2048];
+
+      try (IndexInput orig = dir.openInput("temp.bin", IOContext.DEFAULT)) {
+        IndexInput in;
+        if (startOffset == 0) {
+          in = orig.clone();
+        } else {
+          in = orig.slice("slice", startOffset, totalLength - startOffset);
+        }
+        for (int i = 0; i < 10_000; ++i) {
+          final int startPointer = (int) in.getFilePointer();
+          assertTrue(startPointer < in.length());
+          if (random().nextBoolean()) {
+            final long prefetchLength = TestUtil.nextLong(random(), 1, in.length() - startPointer);
+            in.prefetch(prefetchLength);
+          }
+          assertEquals(startPointer, in.getFilePointer());
+          switch (random().nextInt(100)) {
+            case 0:
+              assertEquals(arr[startOffset + startPointer], in.readByte());
+              break;
+            case 1:
+              if (in.length() - startPointer >= Long.BYTES) {
+                assertEquals(
+                    (long) BitUtil.VH_LE_LONG.get(arr, startOffset + startPointer), in.readLong());
+              }
+              break;
+            default:
+              final int readLength =
+                  TestUtil.nextInt(
+                      random(), 1, (int) Math.min(temp.length, in.length() - startPointer));
+              in.readBytes(temp, 0, readLength);
+              assertArrayEquals(
+                  ArrayUtil.copyOfSubArray(
+                      arr, startOffset + startPointer, startOffset + startPointer + readLength),
+                  ArrayUtil.copyOfSubArray(temp, 0, readLength));
+          }
+          if (in.getFilePointer() == in.length() || random().nextBoolean()) {
+            in.seek(TestUtil.nextInt(random(), 0, (int) in.length() - 1));
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java
index 39c41d46825..7bf3bf56055 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/store/MockIndexInputWrapper.java
@@ -130,6 +130,12 @@ public class MockIndexInputWrapper extends FilterIndexInput {
     in.seek(pos);
   }
 
+  @Override
+  public void prefetch(long length) throws IOException {
+    ensureOpen();
+    in.prefetch(length);
+  }
+
   @Override
   public long length() {
     ensureOpen();