From 5896e5389a83f657781875a852120615ba4763dc Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 19 Aug 2021 16:14:06 -0400 Subject: [PATCH] LUCENE-10057: Use Lucene abstractions to store demo KnnVectorDict (Dawid Weiss) --- .../org/apache/lucene/demo/IndexFiles.java | 34 +++++--- .../org/apache/lucene/demo/SearchFiles.java | 5 +- .../apache/lucene/demo/knn/KnnVectorDict.java | 85 ++++++++++--------- lucene/demo/src/java/overview.html | 12 +++ .../test/org/apache/lucene/demo/TestDemo.java | 7 +- .../lucene/demo/knn/TestDemoEmbeddings.java | 42 ++++----- .../lucene/demo/knn/TestKnnVectorDict.java | 31 ++++--- 7 files changed, 122 insertions(+), 94 deletions(-) diff --git a/lucene/demo/src/java/org/apache/lucene/demo/IndexFiles.java b/lucene/demo/src/java/org/apache/lucene/demo/IndexFiles.java index 1e23c5d9b84..71a63839cee 100644 --- a/lucene/demo/src/java/org/apache/lucene/demo/IndexFiles.java +++ b/lucene/demo/src/java/org/apache/lucene/demo/IndexFiles.java @@ -47,6 +47,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.IOUtils; /** * Index all text files under a directory. @@ -55,17 +56,18 @@ import org.apache.lucene.store.FSDirectory; * command-line arguments for usage information. */ public class IndexFiles implements AutoCloseable { + static final String KNN_DICT = "knn-dict"; // Calculates embedding vectors for KnnVector search private final DemoEmbeddings demoEmbeddings; private final KnnVectorDict vectorDict; - private IndexFiles(Path vectorDictPath) throws IOException { - if (vectorDictPath != null) { - vectorDict = new KnnVectorDict(vectorDictPath); + private IndexFiles(KnnVectorDict vectorDict) throws IOException { + if (vectorDict != null) { + this.vectorDict = vectorDict; demoEmbeddings = new DemoEmbeddings(vectorDict); } else { - vectorDict = null; + this.vectorDict = null; demoEmbeddings = null; } } @@ -80,7 +82,7 @@ public class IndexFiles implements AutoCloseable { + "IF DICT_PATH contains a KnnVector dictionary, the index will also support KnnVector search"; String indexPath = "index"; String docsPath = null; - Path vectorDictPath = null; + String vectorDictSource = null; boolean create = true; for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -91,7 +93,7 @@ public class IndexFiles implements AutoCloseable { docsPath = args[++i]; break; case "-knn_dict": - vectorDictPath = Paths.get(args[++i]); + vectorDictSource = args[++i]; break; case "-update": create = false; @@ -142,8 +144,16 @@ public class IndexFiles implements AutoCloseable { // // iwc.setRAMBufferSizeMB(256.0); + KnnVectorDict vectorDictInstance = null; + long vectorDictSize = 0; + if (vectorDictSource != null) { + KnnVectorDict.build(Paths.get(vectorDictSource), dir, KNN_DICT); + vectorDictInstance = new KnnVectorDict(dir, KNN_DICT); + vectorDictSize = vectorDictInstance.ramBytesUsed(); + } + try (IndexWriter writer = new IndexWriter(dir, iwc); - IndexFiles indexFiles = new IndexFiles(vectorDictPath)) { + IndexFiles indexFiles = new IndexFiles(vectorDictInstance)) { indexFiles.indexDocs(writer, docDir); // NOTE: if you want to maximize search performance, @@ -153,6 +163,8 @@ public class IndexFiles implements AutoCloseable { // you're done adding documents to it): // // writer.forceMerge(1); + } finally { + IOUtils.close(vectorDictInstance); } Date end = new Date(); @@ -163,6 +175,10 @@ public class IndexFiles implements AutoCloseable { + " documents in " + (end.getTime() - start.getTime()) + " milliseconds"); + if (reader.numDocs() > 100 && vectorDictSize < 1_000_000) { + throw new RuntimeException( + "Are you (ab)using the toy vector dictionary? See the package javadocs to understand why you got this exception."); + } } } catch (IOException e) { System.out.println(" caught a " + e.getClass() + "\n with message: " + e.getMessage()); @@ -263,8 +279,6 @@ public class IndexFiles implements AutoCloseable { @Override public void close() throws IOException { - if (vectorDict != null) { - vectorDict.close(); - } + IOUtils.close(vectorDict); } } diff --git a/lucene/demo/src/java/org/apache/lucene/demo/SearchFiles.java b/lucene/demo/src/java/org/apache/lucene/demo/SearchFiles.java index eeaaa95176e..e6195c9a801 100644 --- a/lucene/demo/src/java/org/apache/lucene/demo/SearchFiles.java +++ b/lucene/demo/src/java/org/apache/lucene/demo/SearchFiles.java @@ -31,7 +31,6 @@ import org.apache.lucene.demo.knn.DemoEmbeddings; import org.apache.lucene.demo.knn.KnnVectorDict; import org.apache.lucene.document.Document; import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.queryparser.classic.QueryParser; import org.apache.lucene.search.BooleanClause; @@ -103,12 +102,12 @@ public class SearchFiles { } } - IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index))); + DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index))); IndexSearcher searcher = new IndexSearcher(reader); Analyzer analyzer = new StandardAnalyzer(); KnnVectorDict vectorDict = null; if (knnVectors > 0) { - vectorDict = new KnnVectorDict(Paths.get(index).resolve("knn-dict")); + vectorDict = new KnnVectorDict(reader.directory(), IndexFiles.KNN_DICT); } BufferedReader in; if (queries != null) { diff --git a/lucene/demo/src/java/org/apache/lucene/demo/knn/KnnVectorDict.java b/lucene/demo/src/java/org/apache/lucene/demo/knn/KnnVectorDict.java index 1601ae7b4c2..116fea0daf7 100644 --- a/lucene/demo/src/java/org/apache/lucene/demo/knn/KnnVectorDict.java +++ b/lucene/demo/src/java/org/apache/lucene/demo/knn/KnnVectorDict.java @@ -17,17 +17,19 @@ package org.apache.lucene.demo.knn; import java.io.BufferedReader; -import java.io.DataOutputStream; +import java.io.Closeable; import java.io.IOException; -import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; -import java.nio.channels.FileChannel; import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.regex.Pattern; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IntsRefBuilder; import org.apache.lucene.util.VectorUtil; @@ -40,32 +42,29 @@ import org.apache.lucene.util.fst.Util; * Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is * stored as an FST: token-to-ordinal plus a dense binary file holding the vectors. */ -public class KnnVectorDict implements AutoCloseable { +public class KnnVectorDict implements Closeable { private final FST fst; - private final FileChannel vectors; - private final ByteBuffer vbuffer; + private final IndexInput vectors; private final int dimension; /** * Sole constructor * - * @param knnDictPath the base path name of the files that will store the KnnVectorDict. The file - * with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the '.bin' - * file. + * @param directory Lucene directory from which knn directory should be read. + * @param dictName the base name of the directory files that store the knn vector dictionary. A + * file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the + * '.bin' file. */ - public KnnVectorDict(Path knnDictPath) throws IOException { - String dictName = knnDictPath.getFileName().toString(); - Path fstPath = knnDictPath.resolveSibling(dictName + ".fst"); - Path binPath = knnDictPath.resolveSibling(dictName + ".bin"); - fst = FST.read(fstPath, PositiveIntOutputs.getSingleton()); - vectors = FileChannel.open(binPath); - long size = vectors.size(); - if (size > Integer.MAX_VALUE) { - throw new IllegalArgumentException("vector file is too large: " + size + " bytes"); + public KnnVectorDict(Directory directory, String dictName) throws IOException { + try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) { + fst = new FST<>(fstIn, fstIn, PositiveIntOutputs.getSingleton()); } - vbuffer = vectors.map(FileChannel.MapMode.READ_ONLY, 0, size); - dimension = vbuffer.getInt((int) (size - Integer.BYTES)); + + vectors = directory.openInput(dictName + ".bin", IOContext.READ); + long size = vectors.length(); + vectors.seek(size - Integer.BYTES); + dimension = vectors.readInt(); if ((size - Integer.BYTES) % (dimension * Float.BYTES) != 0) { throw new IllegalStateException( "vector file size " + size + " is not consonant with the vector dimension " + dimension); @@ -96,8 +95,8 @@ public class KnnVectorDict implements AutoCloseable { if (ord == null) { Arrays.fill(output, (byte) 0); } else { - vbuffer.position((int) (ord * dimension * Float.BYTES)); - vbuffer.get(output); + vectors.seek(ord * dimension * Float.BYTES); + vectors.readBytes(output, 0, output.length); } } @@ -122,11 +121,12 @@ public class KnnVectorDict implements AutoCloseable { * and each line is space-delimited. The first column has the token, and the remaining columns * are the vector components, as text. The dictionary must be sorted by its leading tokens * (considered as bytes). - * @param dictOutput a dictionary path prefix. The output will be two files, named by appending - * '.fst' and '.bin' to this path. + * @param directory a Lucene directory to write the dictionary to. + * @param dictName Base name for the knn dictionary files. */ - public static void build(Path gloveInput, Path dictOutput) throws IOException { - new Builder().build(gloveInput, dictOutput); + public static void build(Path gloveInput, Directory directory, String dictName) + throws IOException { + new Builder().build(gloveInput, directory, dictName); } private static class Builder { @@ -140,25 +140,20 @@ public class KnnVectorDict implements AutoCloseable { private long ordinal = 1; private int numFields; - void build(Path gloveInput, Path dictOutput) throws IOException { - String dictName = dictOutput.getFileName().toString(); - Path fstPath = dictOutput.resolveSibling(dictName + ".fst"); - Path binPath = dictOutput.resolveSibling(dictName + ".bin"); + void build(Path gloveInput, Directory directory, String dictName) throws IOException { try (BufferedReader in = Files.newBufferedReader(gloveInput); - OutputStream binOut = Files.newOutputStream(binPath); - DataOutputStream binDataOut = new DataOutputStream(binOut)) { + IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT); + IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) { writeFirstLine(in, binOut); - while (true) { - if (addOneLine(in, binOut) == false) { - break; - } + while (addOneLine(in, binOut)) { + // continue; } - fstCompiler.compile().save(fstPath); - binDataOut.writeInt(numFields - 1); + fstCompiler.compile().save(fstOut, fstOut); + binOut.writeInt(numFields - 1); } } - private void writeFirstLine(BufferedReader in, OutputStream out) throws IOException { + private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException { String[] fields = readOneLine(in); if (fields == null) { return; @@ -178,7 +173,7 @@ public class KnnVectorDict implements AutoCloseable { return SPACE_RE.split(line, 0); } - private boolean addOneLine(BufferedReader in, OutputStream out) throws IOException { + private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException { String[] fields = readOneLine(in); if (fields == null) { return false; @@ -197,7 +192,7 @@ public class KnnVectorDict implements AutoCloseable { return true; } - private void writeVector(String[] fields, OutputStream out) throws IOException { + private void writeVector(String[] fields, IndexOutput out) throws IOException { byteBuffer.position(0); FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); for (int i = 1; i < fields.length; i++) { @@ -205,7 +200,13 @@ public class KnnVectorDict implements AutoCloseable { } VectorUtil.l2normalize(scratch); floatBuffer.put(scratch); - out.write(byteBuffer.array()); + byte[] bytes = byteBuffer.array(); + out.writeBytes(bytes, bytes.length); } } + + /** Return the size of the dictionary in bytes */ + public long ramBytesUsed() { + return fst.ramBytesUsed() + vectors.length(); + } } diff --git a/lucene/demo/src/java/overview.html b/lucene/demo/src/java/overview.html index c20d096abf6..7264e5c089d 100644 --- a/lucene/demo/src/java/overview.html +++ b/lucene/demo/src/java/overview.html @@ -32,6 +32,7 @@
  • Location of the source
  • IndexFiles
  • Searching Files
  • +
  • Working with vector embeddings
  • @@ -203,6 +204,17 @@ IndexSearcher.search(query,n)} method that returns n hits. The results are printed in pages, sorted by score (i.e. relevance).

    +

    Working with vector embeddings

    +
    +

    In addition to indexing and searching text, IndexFiles and SearchFiles can also index and search + numeric vectors derived from that text, known as "embeddings." This demo code uses pre-computed embeddings + provided by the GloVe project, which are in the public + domain. The dictionary here is a tiny subset of the full GloVe dataset. It includes only the words that occur + in the toy data set, and is definitely not ready for production use! If you use this code to create + a vector index for a larger document set, the indexer will throw an exception because + a more complete set of embeddings is needed to get reasonable results. +

    +
    diff --git a/lucene/demo/src/test/org/apache/lucene/demo/TestDemo.java b/lucene/demo/src/test/org/apache/lucene/demo/TestDemo.java index 9d0d0dd90fc..8a2ad5073bd 100644 --- a/lucene/demo/src/test/org/apache/lucene/demo/TestDemo.java +++ b/lucene/demo/src/test/org/apache/lucene/demo/TestDemo.java @@ -20,7 +20,6 @@ import java.io.ByteArrayOutputStream; import java.io.PrintStream; import java.nio.charset.Charset; import java.nio.file.Path; -import org.apache.lucene.demo.knn.KnnVectorDict; import org.apache.lucene.util.LuceneTestCase; public class TestDemo extends LuceneTestCase { @@ -90,10 +89,8 @@ public class TestDemo extends LuceneTestCase { public void testKnnVectorSearch() throws Exception { Path dir = getDataPath("test-files/docs"); Path indexDir = createTempDir("ContribDemoTest"); - Path dictPath = indexDir.resolve("knn-dict"); - Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors"); - KnnVectorDict.build(vectorDictSource, dictPath); + Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors"); IndexFiles.main( new String[] { "-create", @@ -102,7 +99,7 @@ public class TestDemo extends LuceneTestCase { "-index", indexDir.toString(), "-knn_dict", - dictPath.toString() + vectorDictSource.toString() }); // We add a single semantic hit by passing the "-knn_vector 1" argument to SearchFiles. The diff --git a/lucene/demo/src/test/org/apache/lucene/demo/knn/TestDemoEmbeddings.java b/lucene/demo/src/test/org/apache/lucene/demo/knn/TestDemoEmbeddings.java index 0f1c0e0673a..cbdacc3f1ab 100644 --- a/lucene/demo/src/test/org/apache/lucene/demo/knn/TestDemoEmbeddings.java +++ b/lucene/demo/src/test/org/apache/lucene/demo/knn/TestDemoEmbeddings.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.file.Path; +import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; @@ -28,30 +29,31 @@ public class TestDemoEmbeddings extends LuceneTestCase { public void testComputeEmbedding() throws IOException { Path testVectors = getDataPath("../test-files/knn-dict").resolve("knn-token-vectors"); - Path dictPath = createTempDir("knn-demo").resolve("dict"); - KnnVectorDict.build(testVectors, dictPath); - try (KnnVectorDict dict = new KnnVectorDict(dictPath)) { - DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict); + try (Directory directory = newDirectory()) { + KnnVectorDict.build(testVectors, directory, "dict"); + try (KnnVectorDict dict = new KnnVectorDict(directory, "dict")) { + DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict); - // test garbage - float[] garbageVector = - demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife"); - assertEquals(50, garbageVector.length); - assertArrayEquals(new float[50], garbageVector, 0); + // test garbage + float[] garbageVector = + demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife"); + assertEquals(50, garbageVector.length); + assertArrayEquals(new float[50], garbageVector, 0); - // test space - assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0); + // test space + assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0); - // test some real words that are in the dictionary and some that are not - float[] realVector = demoEmbeddings.computeEmbedding("the real fact"); - assertEquals(50, realVector.length); + // test some real words that are in the dictionary and some that are not + float[] realVector = demoEmbeddings.computeEmbedding("the real fact"); + assertEquals(50, realVector.length); - float[] the = getTermVector(dict, "the"); - assertArrayEquals(new float[50], getTermVector(dict, "real"), 0); - float[] fact = getTermVector(dict, "fact"); - VectorUtil.add(the, fact); - VectorUtil.l2normalize(the); - assertArrayEquals(the, realVector, 0); + float[] the = getTermVector(dict, "the"); + assertArrayEquals(new float[50], getTermVector(dict, "real"), 0); + float[] fact = getTermVector(dict, "fact"); + VectorUtil.add(the, fact); + VectorUtil.l2normalize(the); + assertArrayEquals(the, realVector, 0); + } } } diff --git a/lucene/demo/src/test/org/apache/lucene/demo/knn/TestKnnVectorDict.java b/lucene/demo/src/test/org/apache/lucene/demo/knn/TestKnnVectorDict.java index 56aafa5404e..351fe8bb4c2 100644 --- a/lucene/demo/src/test/org/apache/lucene/demo/knn/TestKnnVectorDict.java +++ b/lucene/demo/src/test/org/apache/lucene/demo/knn/TestKnnVectorDict.java @@ -19,6 +19,7 @@ package org.apache.lucene.demo.knn; import java.io.IOException; import java.nio.file.Path; import java.util.Arrays; +import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; @@ -26,23 +27,25 @@ public class TestKnnVectorDict extends LuceneTestCase { public void testBuild() throws IOException { Path testVectors = getDataPath("../test-files/knn-dict").resolve("knn-token-vectors"); - Path dictPath = createTempDir("knn-demo").resolve("dict"); - KnnVectorDict.build(testVectors, dictPath); - try (KnnVectorDict dict = new KnnVectorDict(dictPath)) { - assertEquals(50, dict.getDimension()); - byte[] vector = new byte[dict.getDimension() * Float.BYTES]; - // not found token has zero vector - dict.get(new BytesRef("never saw this token"), vector); - assertArrayEquals(new byte[200], vector); + try (Directory directory = newDirectory()) { + KnnVectorDict.build(testVectors, directory, "dict"); + try (KnnVectorDict dict = new KnnVectorDict(directory, "dict")) { + assertEquals(50, dict.getDimension()); + byte[] vector = new byte[dict.getDimension() * Float.BYTES]; - // found token has nonzero vector - dict.get(new BytesRef("the"), vector); - assertFalse(Arrays.equals(new byte[200], vector)); + // not found token has zero vector + dict.get(new BytesRef("never saw this token"), vector); + assertArrayEquals(new byte[200], vector); - // incorrect dimension for output buffer - expectThrows( - IllegalArgumentException.class, () -> dict.get(new BytesRef("the"), new byte[10])); + // found token has nonzero vector + dict.get(new BytesRef("the"), vector); + assertFalse(Arrays.equals(new byte[200], vector)); + + // incorrect dimension for output buffer + expectThrows( + IllegalArgumentException.class, () -> dict.get(new BytesRef("the"), new byte[10])); + } } } }