LUCENE-10057: Use Lucene abstractions to store demo KnnVectorDict (Dawid Weiss)

Michael Sokolov 2021-08-19 16:14:06 -04:00 committed by GitHub
parent eeb296ce90
commit 5896e5389a
7 changed files with 122 additions and 94 deletions

View File

@@ -47,6 +47,7 @@ import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.IOUtils;
/**
* Index all text files under a directory.
@@ -55,17 +56,18 @@ import org.apache.lucene.store.FSDirectory;
* command-line arguments for usage information.
*/
public class IndexFiles implements AutoCloseable {
static final String KNN_DICT = "knn-dict";
// Calculates embedding vectors for KnnVector search
private final DemoEmbeddings demoEmbeddings;
private final KnnVectorDict vectorDict;
private IndexFiles(Path vectorDictPath) throws IOException {
if (vectorDictPath != null) {
vectorDict = new KnnVectorDict(vectorDictPath);
private IndexFiles(KnnVectorDict vectorDict) throws IOException {
if (vectorDict != null) {
this.vectorDict = vectorDict;
demoEmbeddings = new DemoEmbeddings(vectorDict);
} else {
vectorDict = null;
this.vectorDict = null;
demoEmbeddings = null;
}
}
@@ -80,7 +82,7 @@ public class IndexFiles implements AutoCloseable {
+ "IF DICT_PATH contains a KnnVector dictionary, the index will also support KnnVector search";
String indexPath = "index";
String docsPath = null;
Path vectorDictPath = null;
String vectorDictSource = null;
boolean create = true;
for (int i = 0; i < args.length; i++) {
switch (args[i]) {
@@ -91,7 +93,7 @@ public class IndexFiles implements AutoCloseable {
docsPath = args[++i];
break;
case "-knn_dict":
vectorDictPath = Paths.get(args[++i]);
vectorDictSource = args[++i];
break;
case "-update":
create = false;
@@ -142,8 +144,16 @@ public class IndexFiles implements AutoCloseable {
//
// iwc.setRAMBufferSizeMB(256.0);
KnnVectorDict vectorDictInstance = null;
long vectorDictSize = 0;
if (vectorDictSource != null) {
KnnVectorDict.build(Paths.get(vectorDictSource), dir, KNN_DICT);
vectorDictInstance = new KnnVectorDict(dir, KNN_DICT);
vectorDictSize = vectorDictInstance.ramBytesUsed();
}
try (IndexWriter writer = new IndexWriter(dir, iwc);
IndexFiles indexFiles = new IndexFiles(vectorDictPath)) {
IndexFiles indexFiles = new IndexFiles(vectorDictInstance)) {
indexFiles.indexDocs(writer, docDir);
// NOTE: if you want to maximize search performance,
@@ -153,6 +163,8 @@ public class IndexFiles implements AutoCloseable {
// you're done adding documents to it):
//
// writer.forceMerge(1);
} finally {
IOUtils.close(vectorDictInstance);
}
Date end = new Date();
@@ -163,6 +175,10 @@ public class IndexFiles implements AutoCloseable {
+ " documents in "
+ (end.getTime() - start.getTime())
+ " milliseconds");
if (reader.numDocs() > 100 && vectorDictSize < 1_000_000) {
throw new RuntimeException(
"Are you (ab)using the toy vector dictionary? See the package javadocs to understand why you got this exception.");
}
}
} catch (IOException e) {
System.out.println(" caught a " + e.getClass() + "\n with message: " + e.getMessage());
@@ -263,8 +279,6 @@ public class IndexFiles implements AutoCloseable {
@Override
public void close() throws IOException {
if (vectorDict != null) {
vectorDict.close();
}
IOUtils.close(vectorDict);
}
}
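
The net effect of the IndexFiles changes above is that the knn dictionary is built into, opened from, and closed alongside the same Lucene Directory that holds the index. The following sketch is illustrative and not part of the commit; the "index" and "glove.txt" paths are placeholder assumptions, while the "knn-dict" name matches the KNN_DICT constant introduced above.

import java.nio.file.Paths;
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.IOUtils;

public class KnnDictLifecycleSketch {
  public static void main(String[] args) throws Exception {
    try (Directory dir = FSDirectory.open(Paths.get("index"))) {
      // Convert the GloVe-style text dictionary into "knn-dict.fst" and "knn-dict.bin",
      // stored inside the index directory itself.
      KnnVectorDict.build(Paths.get("glove.txt"), dir, "knn-dict");
      KnnVectorDict vectorDict = new KnnVectorDict(dir, "knn-dict");
      try {
        System.out.println("dimension=" + vectorDict.getDimension()
            + ", approx. size=" + vectorDict.ramBytesUsed() + " bytes");
        // ... hand vectorDict to DemoEmbeddings / an IndexWriter-driven indexing loop here ...
      } finally {
        IOUtils.close(vectorDict); // null-safe close, the same idiom IndexFiles now uses
      }
    }
  }
}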

View File

@@ -31,7 +31,6 @@ import org.apache.lucene.demo.knn.DemoEmbeddings;
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.BooleanClause;
@@ -103,12 +102,12 @@ public class SearchFiles {
}
}
IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index)));
DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(index)));
IndexSearcher searcher = new IndexSearcher(reader);
Analyzer analyzer = new StandardAnalyzer();
KnnVectorDict vectorDict = null;
if (knnVectors > 0) {
vectorDict = new KnnVectorDict(Paths.get(index).resolve("knn-dict"));
vectorDict = new KnnVectorDict(reader.directory(), IndexFiles.KNN_DICT);
}
BufferedReader in;
if (queries != null) {

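On the search side the change is symmetric: SearchFiles now declares a DirectoryReader (which exposes directory()) and opens the dictionary from the index's own Directory under the IndexFiles.KNN_DICT name, instead of resolving a "knn-dict" path on disk. A minimal sketch of that pattern, with the "index" path as an assumption and the dictionary name written out literally because KNN_DICT is package-private:

import java.nio.file.Paths;
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.FSDirectory;

public class KnnSearchSetupSketch {
  public static void main(String[] args) throws Exception {
    try (DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get("index")))) {
      IndexSearcher searcher = new IndexSearcher(reader);
      // "knn-dict" is the same name IndexFiles.KNN_DICT refers to.
      try (KnnVectorDict vectorDict = new KnnVectorDict(reader.directory(), "knn-dict")) {
        // ... embed the query text with new DemoEmbeddings(vectorDict) and run a
        // KnnVector query through 'searcher' here ...
      }
    }
  }
}
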
View File

@@ -17,17 +17,19 @@
package org.apache.lucene.demo.knn;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.regex.Pattern;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.VectorUtil;
@@ -40,32 +42,29 @@ import org.apache.lucene.util.fst.Util;
* Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is
* stored as an FST: token-to-ordinal plus a dense binary file holding the vectors.
*/
public class KnnVectorDict implements AutoCloseable {
public class KnnVectorDict implements Closeable {
private final FST<Long> fst;
private final FileChannel vectors;
private final ByteBuffer vbuffer;
private final IndexInput vectors;
private final int dimension;
/**
* Sole constructor
*
* @param knnDictPath the base path name of the files that will store the KnnVectorDict. The file
* with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the '.bin'
* file.
* @param directory Lucene directory from which the knn dictionary should be read.
* @param dictName the base name of the directory files that store the knn vector dictionary. A
* file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the
* '.bin' file.
*/
public KnnVectorDict(Path knnDictPath) throws IOException {
String dictName = knnDictPath.getFileName().toString();
Path fstPath = knnDictPath.resolveSibling(dictName + ".fst");
Path binPath = knnDictPath.resolveSibling(dictName + ".bin");
fst = FST.read(fstPath, PositiveIntOutputs.getSingleton());
vectors = FileChannel.open(binPath);
long size = vectors.size();
if (size > Integer.MAX_VALUE) {
throw new IllegalArgumentException("vector file is too large: " + size + " bytes");
public KnnVectorDict(Directory directory, String dictName) throws IOException {
try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) {
fst = new FST<>(fstIn, fstIn, PositiveIntOutputs.getSingleton());
}
vbuffer = vectors.map(FileChannel.MapMode.READ_ONLY, 0, size);
dimension = vbuffer.getInt((int) (size - Integer.BYTES));
vectors = directory.openInput(dictName + ".bin", IOContext.READ);
long size = vectors.length();
vectors.seek(size - Integer.BYTES);
dimension = vectors.readInt();
if ((size - Integer.BYTES) % (dimension * Float.BYTES) != 0) {
throw new IllegalStateException(
"vector file size " + size + " is not consonant with the vector dimension " + dimension);
@@ -96,8 +95,8 @@ public class KnnVectorDict implements AutoCloseable {
if (ord == null) {
Arrays.fill(output, (byte) 0);
} else {
vbuffer.position((int) (ord * dimension * Float.BYTES));
vbuffer.get(output);
vectors.seek(ord * dimension * Float.BYTES);
vectors.readBytes(output, 0, output.length);
}
}
@@ -122,11 +121,12 @@
* and each line is space-delimited. The first column has the token, and the remaining columns
* are the vector components, as text. The dictionary must be sorted by its leading tokens
* (considered as bytes).
* @param dictOutput a dictionary path prefix. The output will be two files, named by appending
* '.fst' and '.bin' to this path.
* @param directory a Lucene directory to write the dictionary to.
* @param dictName Base name for the knn dictionary files.
*/
public static void build(Path gloveInput, Path dictOutput) throws IOException {
new Builder().build(gloveInput, dictOutput);
public static void build(Path gloveInput, Directory directory, String dictName)
throws IOException {
new Builder().build(gloveInput, directory, dictName);
}
private static class Builder {
@@ -140,25 +140,20 @@ public class KnnVectorDict implements AutoCloseable {
private long ordinal = 1;
private int numFields;
void build(Path gloveInput, Path dictOutput) throws IOException {
String dictName = dictOutput.getFileName().toString();
Path fstPath = dictOutput.resolveSibling(dictName + ".fst");
Path binPath = dictOutput.resolveSibling(dictName + ".bin");
void build(Path gloveInput, Directory directory, String dictName) throws IOException {
try (BufferedReader in = Files.newBufferedReader(gloveInput);
OutputStream binOut = Files.newOutputStream(binPath);
DataOutputStream binDataOut = new DataOutputStream(binOut)) {
IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT);
IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) {
writeFirstLine(in, binOut);
while (true) {
if (addOneLine(in, binOut) == false) {
break;
}
while (addOneLine(in, binOut)) {
// continue;
}
fstCompiler.compile().save(fstPath);
binDataOut.writeInt(numFields - 1);
fstCompiler.compile().save(fstOut, fstOut);
binOut.writeInt(numFields - 1);
}
}
private void writeFirstLine(BufferedReader in, OutputStream out) throws IOException {
private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException {
String[] fields = readOneLine(in);
if (fields == null) {
return;
@@ -178,7 +173,7 @@
return SPACE_RE.split(line, 0);
}
private boolean addOneLine(BufferedReader in, OutputStream out) throws IOException {
private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException {
String[] fields = readOneLine(in);
if (fields == null) {
return false;
@@ -197,7 +192,7 @@
return true;
}
private void writeVector(String[] fields, OutputStream out) throws IOException {
private void writeVector(String[] fields, IndexOutput out) throws IOException {
byteBuffer.position(0);
FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
for (int i = 1; i < fields.length; i++) {
@@ -205,7 +200,13 @@
}
VectorUtil.l2normalize(scratch);
floatBuffer.put(scratch);
out.write(byteBuffer.array());
byte[] bytes = byteBuffer.array();
out.writeBytes(bytes, bytes.length);
}
}
/** Return the size of the dictionary in bytes */
public long ramBytesUsed() {
return fst.ramBytesUsed() + vectors.length();
}
}
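
Taken together, the rewritten KnnVectorDict reads its '.fst' and '.bin' files through IndexInput, and its public surface is build(), the (Directory, String) constructor, getDimension(), get(), ramBytesUsed() and close(). The helper below is a hypothetical illustration of the lookup contract exercised by the tests further down: the output buffer must be dimension * Float.BYTES long, and tokens missing from the dictionary come back zero-filled.

import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;

public final class KnnDictLookupSketch {
  private KnnDictLookupSketch() {}

  /** Returns true if the given token has a non-zero vector in the dictionary named "dict". */
  public static boolean hasVector(Directory directory, String token) throws IOException {
    // Assumes "dict.fst" and "dict.bin" already exist in the directory,
    // e.g. written earlier by KnnVectorDict.build(glovePath, directory, "dict").
    try (KnnVectorDict dict = new KnnVectorDict(directory, "dict")) {
      byte[] vector = new byte[dict.getDimension() * Float.BYTES];
      dict.get(new BytesRef(token), vector); // unknown tokens are left as all zeros
      return !Arrays.equals(vector, new byte[vector.length]);
    }
  }
}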

View File

@@ -32,6 +32,7 @@
<li><a href="#Location_of_the_source">Location of the source</a></li>
<li><a href="#IndexFiles">IndexFiles</a></li>
<li><a href="#Searching_Files">Searching Files</a></li>
<li><a href="#Embeddings">Working with vector embeddings</a></li>
</ul>
</div>
<a id="About_this_Document"></a>
@@ -203,6 +204,17 @@ IndexSearcher.search(query,n)} method that returns
<span class="codefrag">n</span> hits. The results are printed in pages, sorted
by score (i.e. relevance).</p>
</div>
<h2 id="Embeddings" class="boxed">Working with vector embeddings</h2>
<div class="section">
<p>In addition to indexing and searching text, IndexFiles and SearchFiles can also index and search
numeric vectors derived from that text, known as "embeddings." This demo code uses pre-computed embeddings
provided by the <a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> project, which are in the public
domain. The dictionary here is a tiny subset of the full GloVe dataset. It includes only the words that occur
in the toy data set, and is definitely <i>not ready for production use</i>! If you use this code to create
a vector index for a larger document set, the indexer will throw an exception because
a more complete set of embeddings is needed to get reasonable results.
</p>
</div>
</body>
</html>
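
To make the workflow described in this overview concrete, here is a small driver modeled on TestDemo below: it indexes a folder of text files together with the toy GloVe dictionary, then runs a search that blends one KnnVector hit into the term-based results. All paths are placeholders, and apart from -knn_dict and -knn_vector the flags are the demo's existing command-line options.

import org.apache.lucene.demo.IndexFiles;
import org.apache.lucene.demo.SearchFiles;

public class KnnDemoDriver {
  public static void main(String[] args) throws Exception {
    IndexFiles.main(new String[] {
        "-create",
        "-docs", "path/to/docs",                  // plain-text documents to index
        "-index", "path/to/index",
        "-knn_dict", "path/to/knn-token-vectors"  // GloVe-style token/vector text file
    });
    SearchFiles.main(new String[] {
        "-index", "path/to/index",
        "-query", "lucene search engine",
        "-knn_vector", "1"                        // add a single semantic (KnnVector) hit
    });
  }
}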

View File

@@ -20,7 +20,6 @@ import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Path;
import org.apache.lucene.demo.knn.KnnVectorDict;
import org.apache.lucene.util.LuceneTestCase;
public class TestDemo extends LuceneTestCase {
@@ -90,10 +89,8 @@ public class TestDemo extends LuceneTestCase {
public void testKnnVectorSearch() throws Exception {
Path dir = getDataPath("test-files/docs");
Path indexDir = createTempDir("ContribDemoTest");
Path dictPath = indexDir.resolve("knn-dict");
Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors");
KnnVectorDict.build(vectorDictSource, dictPath);
Path vectorDictSource = getDataPath("test-files/knn-dict").resolve("knn-token-vectors");
IndexFiles.main(
new String[] {
"-create",
@@ -102,7 +99,7 @@
"-index",
indexDir.toString(),
"-knn_dict",
dictPath.toString()
vectorDictSource.toString()
});
// We add a single semantic hit by passing the "-knn_vector 1" argument to SearchFiles. The

View File

@@ -20,6 +20,7 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
@@ -28,30 +29,31 @@ public class TestDemoEmbeddings extends LuceneTestCase {
public void testComputeEmbedding() throws IOException {
Path testVectors = getDataPath("../test-files/knn-dict").resolve("knn-token-vectors");
Path dictPath = createTempDir("knn-demo").resolve("dict");
KnnVectorDict.build(testVectors, dictPath);
try (KnnVectorDict dict = new KnnVectorDict(dictPath)) {
DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict);
try (Directory directory = newDirectory()) {
KnnVectorDict.build(testVectors, directory, "dict");
try (KnnVectorDict dict = new KnnVectorDict(directory, "dict")) {
DemoEmbeddings demoEmbeddings = new DemoEmbeddings(dict);
// test garbage
float[] garbageVector =
demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife");
assertEquals(50, garbageVector.length);
assertArrayEquals(new float[50], garbageVector, 0);
// test garbage
float[] garbageVector =
demoEmbeddings.computeEmbedding("garbagethathasneverbeen seeneverinlife");
assertEquals(50, garbageVector.length);
assertArrayEquals(new float[50], garbageVector, 0);
// test space
assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0);
// test space
assertArrayEquals(new float[50], demoEmbeddings.computeEmbedding(" "), 0);
// test some real words that are in the dictionary and some that are not
float[] realVector = demoEmbeddings.computeEmbedding("the real fact");
assertEquals(50, realVector.length);
// test some real words that are in the dictionary and some that are not
float[] realVector = demoEmbeddings.computeEmbedding("the real fact");
assertEquals(50, realVector.length);
float[] the = getTermVector(dict, "the");
assertArrayEquals(new float[50], getTermVector(dict, "real"), 0);
float[] fact = getTermVector(dict, "fact");
VectorUtil.add(the, fact);
VectorUtil.l2normalize(the);
assertArrayEquals(the, realVector, 0);
float[] the = getTermVector(dict, "the");
assertArrayEquals(new float[50], getTermVector(dict, "real"), 0);
float[] fact = getTermVector(dict, "fact");
VectorUtil.add(the, fact);
VectorUtil.l2normalize(the);
assertArrayEquals(the, realVector, 0);
}
}
}

View File

@@ -19,6 +19,7 @@ package org.apache.lucene.demo.knn;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
@@ -26,23 +27,25 @@ public class TestKnnVectorDict extends LuceneTestCase {
public void testBuild() throws IOException {
Path testVectors = getDataPath("../test-files/knn-dict").resolve("knn-token-vectors");
Path dictPath = createTempDir("knn-demo").resolve("dict");
KnnVectorDict.build(testVectors, dictPath);
try (KnnVectorDict dict = new KnnVectorDict(dictPath)) {
assertEquals(50, dict.getDimension());
byte[] vector = new byte[dict.getDimension() * Float.BYTES];
// not found token has zero vector
dict.get(new BytesRef("never saw this token"), vector);
assertArrayEquals(new byte[200], vector);
try (Directory directory = newDirectory()) {
KnnVectorDict.build(testVectors, directory, "dict");
try (KnnVectorDict dict = new KnnVectorDict(directory, "dict")) {
assertEquals(50, dict.getDimension());
byte[] vector = new byte[dict.getDimension() * Float.BYTES];
// found token has nonzero vector
dict.get(new BytesRef("the"), vector);
assertFalse(Arrays.equals(new byte[200], vector));
// not found token has zero vector
dict.get(new BytesRef("never saw this token"), vector);
assertArrayEquals(new byte[200], vector);
// incorrect dimension for output buffer
expectThrows(
IllegalArgumentException.class, () -> dict.get(new BytesRef("the"), new byte[10]));
// found token has nonzero vector
dict.get(new BytesRef("the"), vector);
assertFalse(Arrays.equals(new byte[200], vector));
// incorrect dimension for output buffer
expectThrows(
IllegalArgumentException.class, () -> dict.get(new BytesRef("the"), new byte[10]));
}
}
}
}