Introduced the Word2VecSynonymFilter (#12169)

Co-authored-by: Alessandro Benedetti <a.benedetti@sease.io>
This commit is contained in:
Daniele Antuzi 2023-04-24 13:35:26 +02:00 committed by GitHub
parent 5e0761eab5
commit 1f4f2bf509
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1450 additions and 23 deletions

View File

@ -135,6 +135,8 @@ New Features
crash the JVM. To disable this feature, pass the following sysprop on Java command line:
"-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler)
* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti)
Improvements
---------------------

View File

@ -89,6 +89,8 @@ import org.apache.lucene.analysis.shingle.ShingleFilter;
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.stempel.StempelStemmer;
import org.apache.lucene.analysis.synonym.SynonymMap;
import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel;
import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.tests.analysis.MockTokenFilter;
@ -99,8 +101,10 @@ import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
import org.apache.lucene.util.AttributeFactory;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IgnoreRandomChains;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.Version;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
@ -415,6 +419,27 @@ public class TestRandomChains extends BaseTokenStreamTestCase {
}
}
});
put(
Word2VecSynonymProvider.class,
random -> {
final int numEntries = atLeast(10);
final int vectorDimension = random.nextInt(99) + 1;
Word2VecModel model = new Word2VecModel(numEntries, vectorDimension);
for (int j = 0; j < numEntries; j++) {
String s = TestUtil.randomSimpleString(random, 10, 20);
float[] vec = new float[vectorDimension];
for (int i = 0; i < vectorDimension; i++) {
vec[i] = random.nextFloat();
}
model.addTermAndVector(new TermAndVector(new BytesRef(s), vec));
}
try {
return new Word2VecSynonymProvider(model);
} catch (IOException e) {
Rethrow.rethrow(e);
return null; // unreachable code
}
});
put(
DateFormat.class,
random -> {

View File

@ -79,6 +79,7 @@ module org.apache.lucene.analysis.common {
exports org.apache.lucene.analysis.sr;
exports org.apache.lucene.analysis.sv;
exports org.apache.lucene.analysis.synonym;
exports org.apache.lucene.analysis.synonym.word2vec;
exports org.apache.lucene.analysis.ta;
exports org.apache.lucene.analysis.te;
exports org.apache.lucene.analysis.th;
@ -257,6 +258,7 @@ module org.apache.lucene.analysis.common {
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory,
org.apache.lucene.analysis.synonym.SynonymFilterFactory,
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory,
org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory,
org.apache.lucene.analysis.core.FlattenGraphFilterFactory,
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory,
org.apache.lucene.analysis.te.TeluguStemFilterFactory,

View File

@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Locale;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndVector;
/**
 * Dl4jModelReader reads the file generated by the library Deeplearning4j and provides a
 * Word2VecModel with normalized vectors.
 *
 * <p>Dl4j Word2Vec documentation:
 * https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to
 * generate a model using dl4j:
 * https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java
 *
 * @lucene.experimental
 */
public class Dl4jModelReader implements Closeable {

  // Dl4j stores the term/vector mapping in a zip entry whose name starts with "syn0"
  private static final String MODEL_FILE_NAME_PREFIX = "syn0";
  // Dl4j prefixes Base64-encoded terms with "B64:" (case-insensitive)
  private static final String B64_TERM_PREFIX = "B64:";

  private final ZipInputStream word2VecModelZipFile;

  public Dl4jModelReader(InputStream stream) {
    this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream));
  }

  /**
   * Parses the mandatory "syn0" entry of the zip model.
   *
   * @return the parsed model, with one normalized vector per dictionary term
   * @throws IOException if the underlying zip stream cannot be read
   * @throws IllegalArgumentException if the "syn0" entry is missing or empty
   */
  public Word2VecModel read() throws IOException {
    ZipEntry entry;
    while ((entry = word2VecModelZipFile.getNextEntry()) != null) {
      String fileName = entry.getName();
      if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) {
        // Intentionally not closed here: closing the reader would close the underlying
        // zip stream, which is owned by this class and released in close().
        BufferedReader reader =
            new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8));
        // Header line: "<dictionarySize> <vectorDimension>"
        String header = reader.readLine();
        if (header == null) {
          throw new IllegalArgumentException(
              "Cannot read Dl4j word2vec model - '" + fileName + "' entry is empty");
        }
        String[] headerValues = header.split(" ");
        int dictionarySize = Integer.parseInt(headerValues[0]);
        int vectorDimension = Integer.parseInt(headerValues[1]);

        Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension);
        String line = reader.readLine();
        // The term encoding is detected once on the first entry and assumed constant for
        // the whole file: dl4j encodes either every term or none of them.
        // regionMatches is length-safe, unlike substring(0, 3) which would throw on a
        // plain-text term shorter than 3 characters.
        boolean isTermB64Encoded = false;
        if (line != null) {
          String[] tokens = line.split(" ");
          isTermB64Encoded =
              tokens[0].regionMatches(true, 0, B64_TERM_PREFIX, 0, B64_TERM_PREFIX.length());
          model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
        }
        while ((line = reader.readLine()) != null) {
          String[] tokens = line.split(" ");
          model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
        }
        return model;
      }
    }
    throw new IllegalArgumentException(
        "Cannot read Dl4j word2vec model - '"
            + MODEL_FILE_NAME_PREFIX
            + "' file is missing in the zip. '"
            + MODEL_FILE_NAME_PREFIX
            + "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library.");
  }

  /**
   * Parses one model line ("term v1 v2 ... vn") into a TermAndVector.
   *
   * @throws RuntimeException if the vector size does not match the declared dimension
   */
  private static TermAndVector extractTermAndVector(
      String[] tokens, int vectorDimension, boolean isTermB64Encoded) {
    BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0]));

    float[] vector = new float[tokens.length - 1];

    if (vectorDimension != vector.length) {
      throw new RuntimeException(
          String.format(
              Locale.ROOT,
              "Word2Vec model file corrupted. "
                  + "Declared vectors of size %d but found vector of size %d for word %s (%s)",
              vectorDimension,
              vector.length,
              tokens[0],
              term.utf8ToString()));
    }

    for (int i = 1; i < tokens.length; i++) {
      vector[i - 1] = Float.parseFloat(tokens[i]);
    }
    return new TermAndVector(term, vector);
  }

  /** Decodes a term of the form "B64:&lt;base64 bytes&gt;" into its raw bytes. */
  static BytesRef decodeB64Term(String term) {
    byte[] buffer = Base64.getDecoder().decode(term.substring(4));
    return new BytesRef(buffer, 0, buffer.length);
  }

  @Override
  public void close() throws IOException {
    word2VecModelZipFile.close();
  }
}

View File

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.IOException;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
 * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
 * word in dictionary
 *
 * @lucene.experimental
 */
public class Word2VecModel implements RandomAccessVectorValues<float[]> {

  private final int dictionarySize;
  private final int vectorDimension;
  private final TermAndVector[] termsAndVectors;
  private final BytesRefHash word2Vec;
  // Next insertion slot in termsAndVectors; also the number of entries stored so far.
  private int loadedCount = 0;

  public Word2VecModel(int dictionarySize, int vectorDimension) {
    this(dictionarySize, vectorDimension, new TermAndVector[dictionarySize], new BytesRefHash());
  }

  private Word2VecModel(
      int dictionarySize,
      int vectorDimension,
      TermAndVector[] termsAndVectors,
      BytesRefHash word2Vec) {
    this.dictionarySize = dictionarySize;
    this.vectorDimension = vectorDimension;
    this.termsAndVectors = termsAndVectors;
    this.word2Vec = word2Vec;
  }

  /**
   * Normalizes and stores the given entry. Entries are appended sequentially, so the ord
   * assigned by the BytesRefHash stays aligned with the entry's index in termsAndVectors —
   * vectorValue(BytesRef) relies on this invariant.
   */
  public void addTermAndVector(TermAndVector modelEntry) {
    modelEntry.normalizeVector();
    this.termsAndVectors[loadedCount++] = modelEntry;
    this.word2Vec.add(modelEntry.getTerm());
  }

  /** Returns the vector stored at the given ordinal. */
  @Override
  public float[] vectorValue(int targetOrd) {
    return termsAndVectors[targetOrd].getVector();
  }

  /** Returns the vector of the given term, or null if the term is unknown. */
  public float[] vectorValue(BytesRef term) {
    final int ord = this.word2Vec.find(term);
    if (ord < 0) {
      return null;
    }
    final TermAndVector entry = this.termsAndVectors[ord];
    if (entry == null) {
      return null;
    }
    return entry.getVector();
  }

  /** Returns the term stored at the given ordinal. */
  public BytesRef termValue(int targetOrd) {
    return termsAndVectors[targetOrd].getTerm();
  }

  @Override
  public int dimension() {
    return vectorDimension;
  }

  @Override
  public int size() {
    return dictionarySize;
  }

  /** Returns a shallow copy sharing the same term/vector storage. */
  @Override
  public RandomAccessVectorValues<float[]> copy() throws IOException {
    return new Word2VecModel(dictionarySize, vectorDimension, termsAndVectors, word2Vec);
  }
}

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.synonym.SynonymGraphFilter;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.TermAndBoost;
/**
 * Applies single-token synonyms from a Word2Vec trained network to an incoming {@link TokenStream}.
 *
 * <p>For every input token the {@link Word2VecSynonymProvider} is queried; each accepted synonym
 * is emitted as an additional token at the same position (position increment 0) with type
 * {@link SynonymGraphFilter#TYPE_SYNONYM}.
 *
 * @lucene.experimental
 */
public final class Word2VecSynonymFilter extends TokenFilter {

  private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
  private final PositionIncrementAttribute posIncrementAtt =
      addAttribute(PositionIncrementAttribute.class);
  private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class);
  private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class);

  private final Word2VecSynonymProvider synonymProvider;
  private final int maxSynonymsPerTerm;
  private final float minAcceptedSimilarity;
  // Synonyms of the most recently returned input token, still waiting to be emitted.
  private final LinkedList<TermAndBoost> synonymBuffer = new LinkedList<>();
  // Captured attribute state of that token; restored for each synonym so offsets and
  // other attributes match the original token's.
  private State lastState;

  /**
   * Apply previously built synonymProvider to incoming tokens.
   *
   * @param input input tokenstream
   * @param synonymProvider synonym provider
   * @param maxSynonymsPerTerm maximum number of result returned by the synonym search
   * @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and
   *     the retrieved ones
   */
  public Word2VecSynonymFilter(
      TokenStream input,
      Word2VecSynonymProvider synonymProvider,
      int maxSynonymsPerTerm,
      float minAcceptedSimilarity) {
    super(input);
    this.synonymProvider = synonymProvider;
    this.maxSynonymsPerTerm = maxSynonymsPerTerm;
    this.minAcceptedSimilarity = minAcceptedSimilarity;
  }

  @Override
  public boolean incrementToken() throws IOException {
    // First drain any synonyms buffered for the previously returned token.
    if (!synonymBuffer.isEmpty()) {
      TermAndBoost synonym = synonymBuffer.pollFirst();
      clearAttributes();
      // Restore the original token's attributes (offsets etc.), then overwrite the
      // term text and position metadata for this synonym.
      restoreState(this.lastState);
      termAtt.setEmpty();
      termAtt.append(synonym.term.utf8ToString());
      typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM);
      posLenAtt.setPositionLength(1);
      // Position increment 0: the synonym occupies the same position as the original.
      posIncrementAtt.setPositionIncrement(0);
      return true;
    }

    if (input.incrementToken()) {
      // Copy the current term text into a BytesRef to look up its synonyms.
      BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
      bytesRefBuilder.copyChars(termAtt.buffer(), 0, termAtt.length());
      BytesRef term = bytesRefBuilder.get();
      List<TermAndBoost> synonyms =
          this.synonymProvider.getSynonyms(term, maxSynonymsPerTerm, minAcceptedSimilarity);
      if (synonyms.size() > 0) {
        // Capture state now so each buffered synonym can be emitted with this
        // token's attributes on subsequent calls.
        this.lastState = captureState();
        this.synonymBuffer.addAll(synonyms);
      }
      // The original token is always returned first, unchanged.
      return true;
    }
    return false;
  }

  @Override
  public void reset() throws IOException {
    super.reset();
    // Drop any synonyms buffered from the previous stream.
    synonymBuffer.clear();
  }
}

View File

@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.IOException;
import java.util.Locale;
import java.util.Map;
import org.apache.lucene.analysis.TokenFilterFactory;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProviderFactory.Word2VecSupportedFormats;
import org.apache.lucene.util.ResourceLoader;
import org.apache.lucene.util.ResourceLoaderAware;
/**
 * Factory for {@link Word2VecSynonymFilter}.
 *
 * @lucene.experimental
 * @lucene.spi {@value #NAME}
 */
public class Word2VecSynonymFilterFactory extends TokenFilterFactory
    implements ResourceLoaderAware {

  /** SPI name */
  public static final String NAME = "Word2VecSynonym";

  public static final int DEFAULT_MAX_SYNONYMS_PER_TERM = 5;
  public static final float DEFAULT_MIN_ACCEPTED_SIMILARITY = 0.8f;

  private final int maxSynonymsPerTerm;
  private final float minAcceptedSimilarity;
  private final Word2VecSupportedFormats format;
  private final String word2vecModelFileName;

  // Loaded lazily in inform(); null until the resource loader is available.
  private Word2VecSynonymProvider synonymProvider;

  public Word2VecSynonymFilterFactory(Map<String, String> args) {
    super(args);
    this.maxSynonymsPerTerm = getInt(args, "maxSynonymsPerTerm", DEFAULT_MAX_SYNONYMS_PER_TERM);
    this.minAcceptedSimilarity =
        getFloat(args, "minAcceptedSimilarity", DEFAULT_MIN_ACCEPTED_SIMILARITY);
    this.word2vecModelFileName = require(args, "model");
    this.format = parseFormat(get(args, "format", "dl4j").toUpperCase(Locale.ROOT));

    if (!args.isEmpty()) {
      throw new IllegalArgumentException("Unknown parameters: " + args);
    }
    if (!(minAcceptedSimilarity > 0 && minAcceptedSimilarity <= 1)) {
      throw new IllegalArgumentException(
          "minAcceptedSimilarity must be in the range (0, 1]. Found: " + minAcceptedSimilarity);
    }
    if (maxSynonymsPerTerm <= 0) {
      throw new IllegalArgumentException(
          "maxSynonymsPerTerm must be a positive integer greater than 0. Found: "
              + maxSynonymsPerTerm);
    }
  }

  /** Maps the (upper-cased) "format" argument onto a supported model format. */
  private static Word2VecSupportedFormats parseFormat(String modelFormat) {
    try {
      return Word2VecSupportedFormats.valueOf(modelFormat);
    } catch (IllegalArgumentException exc) {
      throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc);
    }
  }

  /** Default ctor for compatibility with SPI */
  public Word2VecSynonymFilterFactory() {
    throw defaultCtorException();
  }

  Word2VecSynonymProvider getSynonymProvider() {
    return this.synonymProvider;
  }

  @Override
  public TokenStream create(TokenStream input) {
    // If the model could not be loaded, pass tokens through unchanged.
    if (synonymProvider == null) {
      return input;
    }
    return new Word2VecSynonymFilter(
        input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
  }

  @Override
  public void inform(ResourceLoader loader) throws IOException {
    this.synonymProvider =
        Word2VecSynonymProviderFactory.getSynonymProvider(loader, word2vecModelFileName, format);
  }
}

View File

@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndBoost;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.NeighborQueue;
/**
 * The Word2VecSynonymProvider generates the list of synonyms of a term.
 *
 * @lucene.experimental
 */
public class Word2VecSynonymProvider {

  private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
      VectorSimilarityFunction.DOT_PRODUCT;
  private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32;
  private final Word2VecModel word2VecModel;
  private final HnswGraph hnswGraph;

  /**
   * Word2VecSynonymProvider constructor. Builds an HNSW graph over the model's vectors so
   * synonym lookups are approximate-nearest-neighbour searches.
   *
   * @param model containing the set of TermAndVector entries
   */
  public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
    word2VecModel = model;
    HnswGraphBuilder<float[]> builder =
        HnswGraphBuilder.create(
            word2VecModel,
            VECTOR_ENCODING,
            SIMILARITY_FUNCTION,
            DEFAULT_MAX_CONN,
            DEFAULT_BEAM_WIDTH,
            HnswGraphBuilder.randSeed);
    this.hnswGraph = builder.build(word2VecModel.copy());
  }

  /**
   * Returns the synonyms of the given term, ordered by descending similarity.
   *
   * @param term the term to expand; must not be null
   * @param maxSynonymsPerTerm maximum number of synonyms returned
   * @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector
   *     and the retrieved ones
   * @return at most maxSynonymsPerTerm synonyms; empty if the term is not in the model
   * @throws IOException if the HNSW search fails
   */
  public List<TermAndBoost> getSynonyms(
      BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException {

    if (term == null) {
      throw new IllegalArgumentException("Term must not be null");
    }

    LinkedList<TermAndBoost> result = new LinkedList<>();
    float[] query = word2VecModel.vectorValue(term);
    if (query != null) {
      NeighborQueue synonyms =
          HnswGraphSearcher.search(
              query,
              // The query vector is in the model. When looking for the top-k
              // it's always the nearest neighbour of itself so, we look for the top-k+1
              maxSynonymsPerTerm + 1,
              word2VecModel,
              VECTOR_ENCODING,
              SIMILARITY_FUNCTION,
              hnswGraph,
              null,
              word2VecModel.size());

      int size = synonyms.size();
      for (int i = 0; i < size; i++) {
        float similarity = synonyms.topScore();
        int id = synonyms.pop();

        BytesRef synonym = word2VecModel.termValue(id);
        // We remove the original query term
        if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) {
          // The queue pops in ascending score order; addFirst keeps the result
          // sorted by descending similarity.
          result.addFirst(new TermAndBoost(synonym, similarity));
        }
      }
      // HNSW search is approximate: the query term itself is not guaranteed to be among
      // the k+1 neighbours. When it is absent, the filter above keeps one synonym too
      // many, so trim the lowest-similarity extras to honor maxSynonymsPerTerm.
      while (result.size() > maxSynonymsPerTerm) {
        result.removeLast();
      }
    }
    return result;
  }
}

View File

@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.lucene.util.ResourceLoader;
/**
 * Supply Word2Vec Word2VecSynonymProvider cache avoiding that multiple instances of
 * Word2VecSynonymFilterFactory will instantiate multiple instances of the same SynonymProvider.
 * Assumes synonymProvider implementations are thread-safe.
 */
public class Word2VecSynonymProviderFactory {

  /** Model serialization formats the reader can parse. */
  enum Word2VecSupportedFormats {
    DL4J
  }

  // Cache keyed by model file name. ConcurrentHashMap gives safe concurrent access;
  // putIfAbsent below guarantees a single cached provider per model file.
  private static final Map<String, Word2VecSynonymProvider> word2vecSynonymProviders =
      new ConcurrentHashMap<>();

  private Word2VecSynonymProviderFactory() {
    // static utility class - no instances
  }

  /**
   * Returns the shared synonym provider for the given model file, loading and caching it on
   * first use.
   *
   * @param loader resource loader used to open the model file
   * @param modelFileName name of the word2vec model resource
   * @param format serialization format of the model
   * @throws IOException if the model cannot be read
   */
  public static Word2VecSynonymProvider getSynonymProvider(
      ResourceLoader loader, String modelFileName, Word2VecSupportedFormats format)
      throws IOException {
    Word2VecSynonymProvider synonymProvider = word2vecSynonymProviders.get(modelFileName);
    if (synonymProvider == null) {
      try (InputStream stream = loader.openResource(modelFileName)) {
        try (Dl4jModelReader reader = getModelReader(format, stream)) {
          synonymProvider = new Word2VecSynonymProvider(reader.read());
        }
      }
      // Two threads may race to build the same model; keep whichever lands in the map
      // first so every caller ends up sharing a single provider instance.
      Word2VecSynonymProvider existing =
          word2vecSynonymProviders.putIfAbsent(modelFileName, synonymProvider);
      if (existing != null) {
        synonymProvider = existing;
      }
    }
    return synonymProvider;
  }

  private static Dl4jModelReader getModelReader(
      Word2VecSupportedFormats format, InputStream stream) {
    switch (format) {
      case DL4J:
        return new Dl4jModelReader(stream);
    }
    return null;
  }
}

View File

@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** Analysis components for Synonyms using Word2Vec model. */
package org.apache.lucene.analysis.synonym.word2vec;

View File

@ -118,6 +118,7 @@ org.apache.lucene.analysis.sv.SwedishLightStemFilterFactory
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory
org.apache.lucene.analysis.synonym.SynonymFilterFactory
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory
org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory
org.apache.lucene.analysis.core.FlattenGraphFilterFactory
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory
org.apache.lucene.analysis.te.TeluguStemFilterFactory

View File

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.junit.Test;
public class TestDl4jModelReader extends LuceneTestCase {

  // Fixture zip files resolved from the classpath next to this test class.
  private static final String MODEL_FILE = "word2vec-model.zip";
  private static final String MODEL_EMPTY_FILE = "word2vec-empty-model.zip";
  private static final String CORRUPTED_VECTOR_DIMENSION_MODEL_FILE =
      "word2vec-corrupted-vector-dimension-model.zip";

  // A fresh stream and reader per test instance; the reader owns and closes the stream.
  InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_FILE);
  Dl4jModelReader unit = new Dl4jModelReader(stream);

  @Test
  public void read_zipFileWithMetadata_shouldReturnDictionarySize() throws Exception {
    Word2VecModel model = unit.read();
    // 235 is the dictionary size declared in the fixture model's header line.
    long expectedDictionarySize = 235;
    assertEquals(expectedDictionarySize, model.size());
  }

  @Test
  public void read_zipFileWithMetadata_shouldReturnVectorLength() throws Exception {
    Word2VecModel model = unit.read();
    // 100 is the vector dimension declared in the fixture model's header line.
    int expectedVectorDimension = 100;
    assertEquals(expectedVectorDimension, model.dimension());
  }

  @Test
  public void read_zipFile_shouldReturnDecodedTerm() throws Exception {
    Word2VecModel model = unit.read();
    // The fixture stores terms Base64-encoded; the first entry decodes to "it".
    BytesRef expectedDecodedFirstTerm = new BytesRef("it");
    assertEquals(expectedDecodedFirstTerm, model.termValue(0));
  }

  @Test
  public void decodeTerm_encodedTerm_shouldReturnDecodedTerm() throws Exception {
    byte[] originalInput = "lucene".getBytes(StandardCharsets.UTF_8);
    String B64encodedLuceneTerm = Base64.getEncoder().encodeToString(originalInput);
    // Dl4j marks encoded terms with the "B64:" prefix.
    String word2vecEncodedLuceneTerm = "B64:" + B64encodedLuceneTerm;
    assertEquals(new BytesRef("lucene"), Dl4jModelReader.decodeB64Term(word2vecEncodedLuceneTerm));
  }

  @Test
  public void read_EmptyZipFile_shouldThrowException() throws Exception {
    // A zip without the mandatory "syn0" entry must be rejected.
    try (InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_EMPTY_FILE)) {
      Dl4jModelReader unit = new Dl4jModelReader(stream);
      expectThrows(IllegalArgumentException.class, unit::read);
    }
  }

  @Test
  public void read_corruptedVectorDimensionModelFile_shouldThrowException() throws Exception {
    // A model whose vectors do not match the declared dimension must be rejected.
    try (InputStream stream =
        TestDl4jModelReader.class.getResourceAsStream(CORRUPTED_VECTOR_DIMENSION_MODEL_FILE)) {
      Dl4jModelReader unit = new Dl4jModelReader(stream);
      expectThrows(RuntimeException.class, unit::read);
    }
  }
}

View File

@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.tests.analysis.MockTokenizer;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndVector;
import org.junit.Test;
public class TestWord2VecSynonymFilter extends BaseTokenStreamTestCase {
@Test
public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() throws Exception {
  int maxSynonymPerTerm = 10;
  float minAcceptedSimilarity = 0.9f;
  Word2VecModel model = new Word2VecModel(6, 2);
  model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
  model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
  model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
  model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
  model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
  // "f" points in a clearly different direction, so its similarity to "a" is
  // presumably below the 0.9 threshold and it must not appear in the output.
  model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));

  Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);

  Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
  // Synonyms of "a" are injected at the same position (posInc 0) with its offsets (4-5).
  assertAnalyzesTo(
      a,
      "pre a post", // input
      new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output
      new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset
      new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset
      new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types
      new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements
      new int[] {1, 1, 1, 1, 1, 1, 1}); // posLengths
  a.close();
}
@Test
public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() throws Exception {
  // Only the top-2 synonyms of "a" may be emitted, even though more candidates
  // pass the similarity threshold.
  int maxSynonymPerTerm = 2;
  float minAcceptedSimilarity = 0.9f;
  Word2VecModel model = new Word2VecModel(5, 2);
  model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
  model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
  model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
  model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
  model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));

  Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);

  Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
  // "d" and "e" are the two most similar terms; "b" and "c" are cut by the top-k limit.
  assertAnalyzesTo(
      a,
      "pre a post", // input
      new String[] {"pre", "a", "d", "e", "post"}, // output
      new int[] {0, 4, 4, 4, 6}, // start offset
      new int[] {3, 5, 5, 5, 10}, // end offset
      new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types
      new int[] {1, 1, 0, 0, 1}, // posIncrements
      new int[] {1, 1, 1, 1, 1}); // posLengths
  a.close();
}
@Test
public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Exception {
Word2VecModel model = new Word2VecModel(8, 2);
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {1, 10}));
model.addTermAndVector(new TermAndVector(new BytesRef("post"), new float[] {-10, -11}));
model.addTermAndVector(new TermAndVector(new BytesRef("after"), new float[] {-8, -10}));
Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f);
assertAnalyzesTo(
a,
"pre a post", // input
new String[] {"pre", "a", "d", "e", "c", "b", "post", "after"}, // output
new int[] {0, 4, 4, 4, 4, 4, 6, 6}, // start offset
new int[] {3, 5, 5, 5, 5, 5, 10, 10}, // end offset
new String[] { // types
"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM"
},
new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements
new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths
a.close();
}
@Test
public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms()
throws Exception {
Word2VecModel model = new Word2VecModel(4, 2);
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, -10}));
Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f);
assertAnalyzesTo(
a,
"pre a post", // input
new String[] {"pre", "a", "post"}, // output
new int[] {0, 4, 6}, // start offset
new int[] {3, 5, 10}, // end offset
new String[] {"word", "word", "word"}, // types
new int[] {1, 1, 1}, // posIncrements
new int[] {1, 1, 1}); // posLengths
a.close();
}
private Analyzer getAnalyzer(
Word2VecSynonymProvider synonymProvider,
int maxSynonymsPerTerm,
float minAcceptedSimilarity) {
return new Analyzer() {
@Override
protected TokenStreamComponents createComponents(String fieldName) {
Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false);
// Make a local variable so testRandomHuge doesn't share it across threads!
Word2VecSynonymFilter synFilter =
new Word2VecSynonymFilter(
tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
return new TokenStreamComponents(tokenizer, synFilter);
}
};
}
}

View File

@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import org.apache.lucene.tests.analysis.BaseTokenStreamFactoryTestCase;
import org.apache.lucene.util.ClasspathResourceLoader;
import org.apache.lucene.util.ResourceLoader;
import org.junit.Test;
public class TestWord2VecSynonymFilterFactory extends BaseTokenStreamFactoryTestCase {
public static final String FACTORY_NAME = "Word2VecSynonym";
private static final String WORD2VEC_MODEL_FILE = "word2vec-model.zip";
@Test
public void testInform() throws Exception {
ResourceLoader loader = new ClasspathResourceLoader(getClass());
assertTrue("loader is null and it shouldn't be", loader != null);
Word2VecSynonymFilterFactory factory =
(Word2VecSynonymFilterFactory)
tokenFilterFactory(
FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "minAcceptedSimilarity", "0.7");
Word2VecSynonymProvider synonymProvider = factory.getSynonymProvider();
assertNotEquals(null, synonymProvider);
}
@Test
public void missingRequiredArgument_shouldThrowException() throws Exception {
IllegalArgumentException expected =
expectThrows(
IllegalArgumentException.class,
() -> {
tokenFilterFactory(
FACTORY_NAME,
"format",
"dl4j",
"minAcceptedSimilarity",
"0.7",
"maxSynonymsPerTerm",
"10");
});
assertTrue(expected.getMessage().contains("Configuration Error: missing parameter 'model'"));
}
@Test
public void unsupportedModelFormat_shouldThrowException() throws Exception {
IllegalArgumentException expected =
expectThrows(
IllegalArgumentException.class,
() -> {
tokenFilterFactory(
FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "format", "bogusValue");
});
assertTrue(expected.getMessage().contains("Model format 'BOGUSVALUE' not supported"));
}
@Test
public void bogusArgument_shouldThrowException() throws Exception {
IllegalArgumentException expected =
expectThrows(
IllegalArgumentException.class,
() -> {
tokenFilterFactory(
FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "bogusArg", "bogusValue");
});
assertTrue(expected.getMessage().contains("Unknown parameters"));
}
@Test
public void illegalArguments_shouldThrowException() throws Exception {
IllegalArgumentException expected =
expectThrows(
IllegalArgumentException.class,
() -> {
tokenFilterFactory(
FACTORY_NAME,
"model",
WORD2VEC_MODEL_FILE,
"minAcceptedSimilarity",
"2",
"maxSynonymsPerTerm",
"10");
});
assertTrue(
expected
.getMessage()
.contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 2"));
expected =
expectThrows(
IllegalArgumentException.class,
() -> {
tokenFilterFactory(
FACTORY_NAME,
"model",
WORD2VEC_MODEL_FILE,
"minAcceptedSimilarity",
"0",
"maxSynonymsPerTerm",
"10");
});
assertTrue(
expected
.getMessage()
.contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 0"));
expected =
expectThrows(
IllegalArgumentException.class,
() -> {
tokenFilterFactory(
FACTORY_NAME,
"model",
WORD2VEC_MODEL_FILE,
"minAcceptedSimilarity",
"0.7",
"maxSynonymsPerTerm",
"-1");
});
assertTrue(
expected
.getMessage()
.contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: -1"));
expected =
expectThrows(
IllegalArgumentException.class,
() -> {
tokenFilterFactory(
FACTORY_NAME,
"model",
WORD2VEC_MODEL_FILE,
"minAcceptedSimilarity",
"0.7",
"maxSynonymsPerTerm",
"0");
});
assertTrue(
expected
.getMessage()
.contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: 0"));
}
}

View File

@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.analysis.synonym.word2vec;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndBoost;
import org.apache.lucene.util.TermAndVector;
import org.junit.Test;
/** Unit tests for {@link Word2VecSynonymProvider}. */
public class TestWord2VecSynonymProvider extends LuceneTestCase {

  private static final int MAX_SYNONYMS_PER_TERM = 10;
  private static final float MIN_ACCEPTED_SIMILARITY = 0.85f;

  /** Provider over a minimal two-term model; used by the null-token test. */
  private final Word2VecSynonymProvider unit;

  public TestWord2VecSynonymProvider() throws IOException {
    Word2VecModel model = new Word2VecModel(2, 3);
    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {0.24f, 0.78f, 0.28f}));
    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {0.44f, 0.01f, 0.81f}));
    unit = new Word2VecSynonymProvider(model);
  }

  @Test
  public void getSynonyms_nullToken_shouldThrowException() {
    expectThrows(
        IllegalArgumentException.class,
        () -> unit.getSynonyms(null, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY));
  }

  @Test
  public void getSynonyms_shouldReturnSynonymsBasedOnMinAcceptedSimilarity() throws Exception {
    Word2VecModel model = new Word2VecModel(6, 2);
    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
    model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
    model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
    model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));

    // Local provider renamed from 'unit' so it no longer shadows the class field.
    Word2VecSynonymProvider provider = new Word2VecSynonymProvider(model);

    BytesRef inputTerm = new BytesRef("a");
    // Expected synonyms of "a", ordered by decreasing similarity; "f" falls below the threshold.
    String[] expectedSynonyms = {"d", "e", "c", "b"};
    List<TermAndBoost> actualSynonymsResults =
        provider.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);

    assertEquals(4, actualSynonymsResults.size());
    for (int i = 0; i < expectedSynonyms.length; i++) {
      assertEquals(new BytesRef(expectedSynonyms[i]), actualSynonymsResults.get(i).term);
    }
  }

  @Test
  public void getSynonyms_shouldReturnSynonymsBoost() throws Exception {
    Word2VecModel model = new Word2VecModel(3, 2);
    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {1, 1}));
    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {99, 101}));

    Word2VecSynonymProvider provider = new Word2VecSynonymProvider(model);

    BytesRef inputTerm = new BytesRef("a");
    List<TermAndBoost> actualSynonymsResults =
        provider.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);

    // "b" is collinear with "a", so its boost (cosine similarity) must be 1.0.
    BytesRef expectedFirstSynonymTerm = new BytesRef("b");
    double expectedFirstSynonymBoost = 1.0;
    assertEquals(expectedFirstSynonymTerm, actualSynonymsResults.get(0).term);
    assertEquals(expectedFirstSynonymBoost, actualSynonymsResults.get(0).boost, 0.001f);
  }

  @Test
  public void noSynonymsWithinAcceptedSimilarity_shouldReturnNoSynonyms() throws Exception {
    Word2VecModel model = new Word2VecModel(4, 2);
    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
    model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {6, -6}));

    Word2VecSynonymProvider provider = new Word2VecSynonymProvider(model);

    // Use new BytesRef(...) for consistency with the other tests (was newBytesRef(...)).
    BytesRef inputTerm = new BytesRef("a");
    List<TermAndBoost> actualSynonymsResults =
        provider.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);

    assertEquals(0, actualSynonymsResults.size());
  }

  @Test
  public void testModel_shouldReturnNormalizedVectors() {
    // The model is expected to hand back unit-length vectors for stored terms.
    Word2VecModel model = new Word2VecModel(4, 2);
    model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
    model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
    model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
    model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));

    float[] vectorIdA = model.vectorValue(new BytesRef("a"));
    float[] vectorIdF = model.vectorValue(new BytesRef("f"));
    assertArrayEquals(new float[] {0.70710f, 0.70710f}, vectorIdA, 0.001f);
    assertArrayEquals(new float[] {-0.0995f, 0.99503f}, vectorIdF, 0.001f);
  }

  @Test
  public void normalizedVector_shouldReturnModule1() {
    TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10});
    synonymTerm.normalizeVector();
    float[] vector = synonymTerm.getVector();
    // After normalization the L2 norm must be 1.
    float len = 0;
    for (int i = 0; i < vector.length; i++) {
      len += vector[i] * vector[i];
    }
    assertEquals(1, Math.sqrt(len), 0.0001f);
  }
}

View File

@ -62,20 +62,6 @@ public class QueryBuilder {
protected boolean enableGraphQueries = true;
protected boolean autoGenerateMultiTermSynonymsPhraseQuery = false;
/** Wraps a term and boost */
public static class TermAndBoost {
/** the term */
public final BytesRef term;
/** the boost */
public final float boost;
/** Creates a new TermAndBoost */
public TermAndBoost(BytesRef term, float boost) {
this.term = BytesRef.deepCopyOf(term);
this.boost = boost;
}
}
/** Creates a new QueryBuilder using the given analyzer. */
public QueryBuilder(Analyzer analyzer) {
this.analyzer = analyzer;

View File

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
/** Wraps a term and boost */
/** Wraps a term and boost */
public class TermAndBoost {
  /** the term; a private deep copy of the {@link BytesRef} passed to the constructor */
  public final BytesRef term;
  /** the boost */
  public final float boost;

  /**
   * Creates a new TermAndBoost.
   *
   * @param term the term; deep-copied so later mutation of the caller's BytesRef cannot leak in
   * @param boost the boost associated with the term
   */
  public TermAndBoost(BytesRef term, float boost) {
    this.term = BytesRef.deepCopyOf(term);
    this.boost = boost;
  }
}

View File

@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.util.Locale;
/**
* Word2Vec unit composed by a term with the associated vector
*
* @lucene.experimental
*/
/**
 * Word2Vec unit composed by a term with the associated vector
 *
 * @lucene.experimental
 */
public class TermAndVector {

  private final BytesRef term;
  private final float[] vector;

  /**
   * Creates a new TermAndVector. The term and vector are stored by reference (no defensive copy).
   *
   * @param term the term
   * @param vector the vector associated with the term
   */
  public TermAndVector(BytesRef term, float[] vector) {
    this.term = term;
    this.vector = vector;
  }

  /** Returns the term. */
  public BytesRef getTerm() {
    return this.term;
  }

  /** Returns the backing vector array (not a copy: callers share mutations). */
  public float[] getVector() {
    return this.vector;
  }

  /** Returns the vector dimension. */
  public int size() {
    return vector.length;
  }

  /**
   * Normalizes the vector in place to unit (L2) length.
   *
   * <p>A zero vector is left unchanged: dividing by a zero norm would fill the vector with NaN
   * components.
   */
  public void normalizeVector() {
    float vectorLength = 0;
    for (int i = 0; i < vector.length; i++) {
      vectorLength += vector[i] * vector[i];
    }
    vectorLength = (float) Math.sqrt(vectorLength);
    if (vectorLength == 0) {
      return; // nothing to scale; avoids NaN from division by zero
    }
    for (int i = 0; i < vector.length; i++) {
      vector[i] /= vectorLength;
    }
  }

  /** Returns {@code "<term> [v1,v2,...]"} with each component formatted to three decimals. */
  @Override
  public String toString() {
    StringBuilder builder = new StringBuilder(this.term.utf8ToString());
    builder.append(" [");
    if (vector.length > 0) {
      for (int i = 0; i < vector.length - 1; i++) {
        builder.append(String.format(Locale.ROOT, "%.3f,", vector[i]));
      }
      builder.append(String.format(Locale.ROOT, "%.3f", vector[vector.length - 1]));
    }
    // Close the bracket unconditionally: previously an empty vector rendered as "term [" with
    // no closing "]".
    builder.append(']');
    return builder.toString();
  }
}

View File

@ -41,8 +41,17 @@ import org.apache.lucene.util.InfoStream;
*/
public final class HnswGraphBuilder<T> {
/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;
/**
* Default number of the size of the queue maintained while searching during a graph construction.
*/
public static final int DEFAULT_BEAM_WIDTH = 100;
/** Default random seed for level generation * */
private static final long DEFAULT_RAND_SEED = 42;
/** A name for the HNSW component for the info-stream * */
public static final String HNSW_COMPONENT = "HNSW";

View File

@ -54,6 +54,7 @@ import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before;
@ -62,7 +63,7 @@ public class TestKnnGraph extends LuceneTestCase {
private static final String KNN_GRAPH_FIELD = "vector";
private static int M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
private static int M = HnswGraphBuilder.DEFAULT_MAX_CONN;
private Codec codec;
private Codec float32Codec;
@ -80,7 +81,7 @@ public class TestKnnGraph extends LuceneTestCase {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
@ -92,7 +93,7 @@ public class TestKnnGraph extends LuceneTestCase {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
@ -103,7 +104,7 @@ public class TestKnnGraph extends LuceneTestCase {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}
@ -115,7 +116,7 @@ public class TestKnnGraph extends LuceneTestCase {
@After
public void cleanup() {
M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
M = HnswGraphBuilder.DEFAULT_MAX_CONN;
}
/** Basic test of creating documents in a graph */

View File

@ -55,6 +55,7 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.search.BoostAttribute;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@ -154,7 +155,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
boolean[] keywordAtts,
boolean graphOffsetsAreCorrect,
byte[][] payloads,
int[] flags)
int[] flags,
float[] boost)
throws IOException {
assertNotNull(output);
CheckClearAttributesAttribute checkClearAtt =
@ -221,6 +223,12 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
flagsAtt = ts.getAttribute(FlagsAttribute.class);
}
BoostAttribute boostAtt = null;
if (boost != null) {
assertTrue("has no BoostAttribute", ts.hasAttribute(BoostAttribute.class));
boostAtt = ts.getAttribute(BoostAttribute.class);
}
// Maps position to the start/end offset:
final Map<Integer, Integer> posToStartOffset = new HashMap<>();
final Map<Integer, Integer> posToEndOffset = new HashMap<>();
@ -243,6 +251,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
if (payloadAtt != null)
payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24}));
if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's
if (boostAtt != null) boostAtt.setBoost(-1f);
checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
assertTrue("token " + i + " does not exist", ts.incrementToken());
@ -278,6 +287,9 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
if (flagsAtt != null) {
assertEquals("flagsAtt " + i + " term=" + termAtt, flags[i], flagsAtt.getFlags());
}
if (boostAtt != null) {
assertEquals("boostAtt " + i + " term=" + termAtt, boost[i], boostAtt.getBoost(), 0.001);
}
if (payloads != null) {
if (payloads[i] != null) {
assertEquals("payloads " + i, new BytesRef(payloads[i]), payloadAtt.getPayload());
@ -405,6 +417,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
if (payloadAtt != null)
payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24}));
if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's
if (boostAtt != null) boostAtt.setBoost(-1);
checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
@ -426,6 +439,38 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
ts.close();
}
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
int[] startOffsets,
int[] endOffsets,
String[] types,
int[] posIncrements,
int[] posLengths,
Integer finalOffset,
Integer finalPosInc,
boolean[] keywordAtts,
boolean graphOffsetsAreCorrect,
byte[][] payloads,
int[] flags)
throws IOException {
assertTokenStreamContents(
ts,
output,
startOffsets,
endOffsets,
types,
posIncrements,
posLengths,
finalOffset,
finalPosInc,
keywordAtts,
graphOffsetsAreCorrect,
payloads,
flags,
null);
}
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
@ -438,6 +483,33 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
boolean[] keywordAtts,
boolean graphOffsetsAreCorrect)
throws IOException {
assertTokenStreamContents(
ts,
output,
startOffsets,
endOffsets,
types,
posIncrements,
posLengths,
finalOffset,
keywordAtts,
graphOffsetsAreCorrect,
null);
}
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
int[] startOffsets,
int[] endOffsets,
String[] types,
int[] posIncrements,
int[] posLengths,
Integer finalOffset,
boolean[] keywordAtts,
boolean graphOffsetsAreCorrect,
float[] boost)
throws IOException {
assertTokenStreamContents(
ts,
output,
@ -451,7 +523,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
keywordAtts,
graphOffsetsAreCorrect,
null,
null);
null,
boost);
}
public static void assertTokenStreamContents(
@ -481,9 +554,36 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
keywordAtts,
graphOffsetsAreCorrect,
payloads,
null,
null);
}
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
int[] startOffsets,
int[] endOffsets,
String[] types,
int[] posIncrements,
int[] posLengths,
Integer finalOffset,
boolean graphOffsetsAreCorrect,
float[] boost)
throws IOException {
assertTokenStreamContents(
ts,
output,
startOffsets,
endOffsets,
types,
posIncrements,
posLengths,
finalOffset,
null,
graphOffsetsAreCorrect,
boost);
}
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
@ -505,7 +605,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
posLengths,
finalOffset,
null,
graphOffsetsAreCorrect);
graphOffsetsAreCorrect,
null);
}
public static void assertTokenStreamContents(
@ -522,6 +623,30 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
ts, output, startOffsets, endOffsets, types, posIncrements, posLengths, finalOffset, true);
}
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
int[] startOffsets,
int[] endOffsets,
String[] types,
int[] posIncrements,
int[] posLengths,
Integer finalOffset,
float[] boost)
throws IOException {
assertTokenStreamContents(
ts,
output,
startOffsets,
endOffsets,
types,
posIncrements,
posLengths,
finalOffset,
true,
boost);
}
public static void assertTokenStreamContents(
TokenStream ts,
String[] output,
@ -649,6 +774,21 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
int[] posIncrements,
int[] posLengths)
throws IOException {
assertAnalyzesTo(
a, input, output, startOffsets, endOffsets, types, posIncrements, posLengths, null);
}
public static void assertAnalyzesTo(
Analyzer a,
String input,
String[] output,
int[] startOffsets,
int[] endOffsets,
String[] types,
int[] posIncrements,
int[] posLengths,
float[] boost)
throws IOException {
assertTokenStreamContents(
a.tokenStream("dummy", input),
output,
@ -657,7 +797,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
types,
posIncrements,
posLengths,
input.length());
input.length(),
boost);
checkResetException(a, input);
checkAnalysisConsistency(random(), a, true, input);
}