mirror of https://github.com/apache/lucene.git
Introduced the Word2VecSynonymFilter (#12169)
Co-authored-by: Alessandro Benedetti <a.benedetti@sease.io>
This commit is contained in:
parent
5e0761eab5
commit
1f4f2bf509
|
@ -135,6 +135,8 @@ New Features
|
|||
crash the JVM. To disable this feature, pass the following sysprop on Java command line:
|
||||
"-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler)
|
||||
|
||||
* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti)
|
||||
|
||||
Improvements
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -89,6 +89,8 @@ import org.apache.lucene.analysis.shingle.ShingleFilter;
|
|||
import org.apache.lucene.analysis.standard.StandardTokenizer;
|
||||
import org.apache.lucene.analysis.stempel.StempelStemmer;
|
||||
import org.apache.lucene.analysis.synonym.SynonymMap;
|
||||
import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel;
|
||||
import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider;
|
||||
import org.apache.lucene.store.ByteBuffersDirectory;
|
||||
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
|
||||
import org.apache.lucene.tests.analysis.MockTokenFilter;
|
||||
|
@ -99,8 +101,10 @@ import org.apache.lucene.tests.util.TestUtil;
|
|||
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
|
||||
import org.apache.lucene.util.AttributeFactory;
|
||||
import org.apache.lucene.util.AttributeSource;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.CharsRef;
|
||||
import org.apache.lucene.util.IgnoreRandomChains;
|
||||
import org.apache.lucene.util.TermAndVector;
|
||||
import org.apache.lucene.util.Version;
|
||||
import org.apache.lucene.util.automaton.Automaton;
|
||||
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
|
||||
|
@ -415,6 +419,27 @@ public class TestRandomChains extends BaseTokenStreamTestCase {
|
|||
}
|
||||
}
|
||||
});
|
||||
put(
    Word2VecSynonymProvider.class,
    random -> {
      // Build a small random model: at least 10 terms, vectors of 1..99 dimensions.
      final int entryCount = atLeast(10);
      final int dimension = 1 + random.nextInt(99);
      Word2VecModel randomModel = new Word2VecModel(entryCount, dimension);
      for (int entry = 0; entry < entryCount; entry++) {
        // The term is drawn before its vector components so the random sequence is stable.
        String term = TestUtil.randomSimpleString(random, 10, 20);
        float[] components = new float[dimension];
        for (int d = 0; d < components.length; d++) {
          components[d] = random.nextFloat();
        }
        randomModel.addTermAndVector(new TermAndVector(new BytesRef(term), components));
      }
      try {
        return new Word2VecSynonymProvider(randomModel);
      } catch (IOException e) {
        Rethrow.rethrow(e);
        return null; // unreachable code
      }
    });
|
||||
put(
|
||||
DateFormat.class,
|
||||
random -> {
|
||||
|
|
|
@ -79,6 +79,7 @@ module org.apache.lucene.analysis.common {
|
|||
exports org.apache.lucene.analysis.sr;
|
||||
exports org.apache.lucene.analysis.sv;
|
||||
exports org.apache.lucene.analysis.synonym;
|
||||
exports org.apache.lucene.analysis.synonym.word2vec;
|
||||
exports org.apache.lucene.analysis.ta;
|
||||
exports org.apache.lucene.analysis.te;
|
||||
exports org.apache.lucene.analysis.th;
|
||||
|
@ -257,6 +258,7 @@ module org.apache.lucene.analysis.common {
|
|||
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory,
|
||||
org.apache.lucene.analysis.synonym.SynonymFilterFactory,
|
||||
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory,
|
||||
org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory,
|
||||
org.apache.lucene.analysis.core.FlattenGraphFilterFactory,
|
||||
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory,
|
||||
org.apache.lucene.analysis.te.TeluguStemFilterFactory,
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Base64;
|
||||
import java.util.Locale;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipInputStream;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.TermAndVector;
|
||||
|
||||
/**
|
||||
* Dl4jModelReader reads the file generated by the library Deeplearning4j and provide a
|
||||
* Word2VecModel with normalized vectors
|
||||
*
|
||||
* <p>Dl4j Word2Vec documentation:
|
||||
* https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to
|
||||
* generate a model using dl4j:
|
||||
* https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class Dl4jModelReader implements Closeable {
|
||||
|
||||
private static final String MODEL_FILE_NAME_PREFIX = "syn0";
|
||||
|
||||
private final ZipInputStream word2VecModelZipFile;
|
||||
|
||||
public Dl4jModelReader(InputStream stream) {
|
||||
this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream));
|
||||
}
|
||||
|
||||
public Word2VecModel read() throws IOException {
|
||||
|
||||
ZipEntry entry;
|
||||
while ((entry = word2VecModelZipFile.getNextEntry()) != null) {
|
||||
String fileName = entry.getName();
|
||||
if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) {
|
||||
BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8));
|
||||
|
||||
String header = reader.readLine();
|
||||
String[] headerValues = header.split(" ");
|
||||
int dictionarySize = Integer.parseInt(headerValues[0]);
|
||||
int vectorDimension = Integer.parseInt(headerValues[1]);
|
||||
|
||||
Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension);
|
||||
String line = reader.readLine();
|
||||
boolean isTermB64Encoded = false;
|
||||
if (line != null) {
|
||||
String[] tokens = line.split(" ");
|
||||
isTermB64Encoded =
|
||||
tokens[0].substring(0, 3).toLowerCase(Locale.ROOT).compareTo("b64") == 0;
|
||||
model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
|
||||
}
|
||||
while ((line = reader.readLine()) != null) {
|
||||
String[] tokens = line.split(" ");
|
||||
model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
|
||||
}
|
||||
return model;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"Cannot read Dl4j word2vec model - '"
|
||||
+ MODEL_FILE_NAME_PREFIX
|
||||
+ "' file is missing in the zip. '"
|
||||
+ MODEL_FILE_NAME_PREFIX
|
||||
+ "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library.");
|
||||
}
|
||||
|
||||
private static TermAndVector extractTermAndVector(
|
||||
String[] tokens, int vectorDimension, boolean isTermB64Encoded) {
|
||||
BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0]));
|
||||
|
||||
float[] vector = new float[tokens.length - 1];
|
||||
|
||||
if (vectorDimension != vector.length) {
|
||||
throw new RuntimeException(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"Word2Vec model file corrupted. "
|
||||
+ "Declared vectors of size %d but found vector of size %d for word %s (%s)",
|
||||
vectorDimension,
|
||||
vector.length,
|
||||
tokens[0],
|
||||
term.utf8ToString()));
|
||||
}
|
||||
|
||||
for (int i = 1; i < tokens.length; i++) {
|
||||
vector[i - 1] = Float.parseFloat(tokens[i]);
|
||||
}
|
||||
return new TermAndVector(term, vector);
|
||||
}
|
||||
|
||||
static BytesRef decodeB64Term(String term) {
|
||||
byte[] buffer = Base64.getDecoder().decode(term.substring(4));
|
||||
return new BytesRef(buffer, 0, buffer.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
word2VecModelZipFile.close();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.BytesRefHash;
|
||||
import org.apache.lucene.util.TermAndVector;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
|
||||
* word in dictionary
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class Word2VecModel implements RandomAccessVectorValues<float[]> {
|
||||
|
||||
private final int dictionarySize;
|
||||
private final int vectorDimension;
|
||||
private final TermAndVector[] termsAndVectors;
|
||||
private final BytesRefHash word2Vec;
|
||||
private int loadedCount = 0;
|
||||
|
||||
public Word2VecModel(int dictionarySize, int vectorDimension) {
|
||||
this.dictionarySize = dictionarySize;
|
||||
this.vectorDimension = vectorDimension;
|
||||
this.termsAndVectors = new TermAndVector[dictionarySize];
|
||||
this.word2Vec = new BytesRefHash();
|
||||
}
|
||||
|
||||
private Word2VecModel(
|
||||
int dictionarySize,
|
||||
int vectorDimension,
|
||||
TermAndVector[] termsAndVectors,
|
||||
BytesRefHash word2Vec) {
|
||||
this.dictionarySize = dictionarySize;
|
||||
this.vectorDimension = vectorDimension;
|
||||
this.termsAndVectors = termsAndVectors;
|
||||
this.word2Vec = word2Vec;
|
||||
}
|
||||
|
||||
public void addTermAndVector(TermAndVector modelEntry) {
|
||||
modelEntry.normalizeVector();
|
||||
this.termsAndVectors[loadedCount++] = modelEntry;
|
||||
this.word2Vec.add(modelEntry.getTerm());
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return termsAndVectors[targetOrd].getVector();
|
||||
}
|
||||
|
||||
public float[] vectorValue(BytesRef term) {
|
||||
int termOrd = this.word2Vec.find(term);
|
||||
if (termOrd < 0) return null;
|
||||
TermAndVector entry = this.termsAndVectors[termOrd];
|
||||
return (entry == null) ? null : entry.getVector();
|
||||
}
|
||||
|
||||
public BytesRef termValue(int targetOrd) {
|
||||
return termsAndVectors[targetOrd].getTerm();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return vectorDimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return dictionarySize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
return new Word2VecModel(
|
||||
this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.analysis.TokenFilter;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.synonym.SynonymGraphFilter;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
|
||||
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
|
||||
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.BytesRefBuilder;
|
||||
import org.apache.lucene.util.TermAndBoost;
|
||||
|
||||
/**
|
||||
* Applies single-token synonyms from a Word2Vec trained network to an incoming {@link TokenStream}.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Word2VecSynonymFilter extends TokenFilter {
|
||||
|
||||
private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
|
||||
private final PositionIncrementAttribute posIncrementAtt =
|
||||
addAttribute(PositionIncrementAttribute.class);
|
||||
private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class);
|
||||
private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class);
|
||||
|
||||
private final Word2VecSynonymProvider synonymProvider;
|
||||
private final int maxSynonymsPerTerm;
|
||||
private final float minAcceptedSimilarity;
|
||||
private final LinkedList<TermAndBoost> synonymBuffer = new LinkedList<>();
|
||||
private State lastState;
|
||||
|
||||
/**
|
||||
* Apply previously built synonymProvider to incoming tokens.
|
||||
*
|
||||
* @param input input tokenstream
|
||||
* @param synonymProvider synonym provider
|
||||
* @param maxSynonymsPerTerm maximum number of result returned by the synonym search
|
||||
* @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and
|
||||
* the retrieved ones
|
||||
*/
|
||||
public Word2VecSynonymFilter(
|
||||
TokenStream input,
|
||||
Word2VecSynonymProvider synonymProvider,
|
||||
int maxSynonymsPerTerm,
|
||||
float minAcceptedSimilarity) {
|
||||
super(input);
|
||||
this.synonymProvider = synonymProvider;
|
||||
this.maxSynonymsPerTerm = maxSynonymsPerTerm;
|
||||
this.minAcceptedSimilarity = minAcceptedSimilarity;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean incrementToken() throws IOException {
|
||||
|
||||
if (!synonymBuffer.isEmpty()) {
|
||||
TermAndBoost synonym = synonymBuffer.pollFirst();
|
||||
clearAttributes();
|
||||
restoreState(this.lastState);
|
||||
termAtt.setEmpty();
|
||||
termAtt.append(synonym.term.utf8ToString());
|
||||
typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM);
|
||||
posLenAtt.setPositionLength(1);
|
||||
posIncrementAtt.setPositionIncrement(0);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (input.incrementToken()) {
|
||||
BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
|
||||
bytesRefBuilder.copyChars(termAtt.buffer(), 0, termAtt.length());
|
||||
BytesRef term = bytesRefBuilder.get();
|
||||
List<TermAndBoost> synonyms =
|
||||
this.synonymProvider.getSynonyms(term, maxSynonymsPerTerm, minAcceptedSimilarity);
|
||||
if (synonyms.size() > 0) {
|
||||
this.lastState = captureState();
|
||||
this.synonymBuffer.addAll(synonyms);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() throws IOException {
|
||||
super.reset();
|
||||
synonymBuffer.clear();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.analysis.TokenFilterFactory;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProviderFactory.Word2VecSupportedFormats;
|
||||
import org.apache.lucene.util.ResourceLoader;
|
||||
import org.apache.lucene.util.ResourceLoaderAware;
|
||||
|
||||
/**
|
||||
* Factory for {@link Word2VecSynonymFilter}.
|
||||
*
|
||||
* @lucene.experimental
|
||||
* @lucene.spi {@value #NAME}
|
||||
*/
|
||||
public class Word2VecSynonymFilterFactory extends TokenFilterFactory
|
||||
implements ResourceLoaderAware {
|
||||
|
||||
/** SPI name */
|
||||
public static final String NAME = "Word2VecSynonym";
|
||||
|
||||
public static final int DEFAULT_MAX_SYNONYMS_PER_TERM = 5;
|
||||
public static final float DEFAULT_MIN_ACCEPTED_SIMILARITY = 0.8f;
|
||||
|
||||
private final int maxSynonymsPerTerm;
|
||||
private final float minAcceptedSimilarity;
|
||||
private final Word2VecSupportedFormats format;
|
||||
private final String word2vecModelFileName;
|
||||
|
||||
private Word2VecSynonymProvider synonymProvider;
|
||||
|
||||
public Word2VecSynonymFilterFactory(Map<String, String> args) {
|
||||
super(args);
|
||||
this.maxSynonymsPerTerm = getInt(args, "maxSynonymsPerTerm", DEFAULT_MAX_SYNONYMS_PER_TERM);
|
||||
this.minAcceptedSimilarity =
|
||||
getFloat(args, "minAcceptedSimilarity", DEFAULT_MIN_ACCEPTED_SIMILARITY);
|
||||
this.word2vecModelFileName = require(args, "model");
|
||||
|
||||
String modelFormat = get(args, "format", "dl4j").toUpperCase(Locale.ROOT);
|
||||
try {
|
||||
this.format = Word2VecSupportedFormats.valueOf(modelFormat);
|
||||
} catch (IllegalArgumentException exc) {
|
||||
throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc);
|
||||
}
|
||||
|
||||
if (!args.isEmpty()) {
|
||||
throw new IllegalArgumentException("Unknown parameters: " + args);
|
||||
}
|
||||
if (minAcceptedSimilarity <= 0 || minAcceptedSimilarity > 1) {
|
||||
throw new IllegalArgumentException(
|
||||
"minAcceptedSimilarity must be in the range (0, 1]. Found: " + minAcceptedSimilarity);
|
||||
}
|
||||
if (maxSynonymsPerTerm <= 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"maxSynonymsPerTerm must be a positive integer greater than 0. Found: "
|
||||
+ maxSynonymsPerTerm);
|
||||
}
|
||||
}
|
||||
|
||||
/** Default ctor for compatibility with SPI */
|
||||
public Word2VecSynonymFilterFactory() {
|
||||
throw defaultCtorException();
|
||||
}
|
||||
|
||||
Word2VecSynonymProvider getSynonymProvider() {
|
||||
return this.synonymProvider;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TokenStream create(TokenStream input) {
|
||||
return synonymProvider == null
|
||||
? input
|
||||
: new Word2VecSynonymFilter(
|
||||
input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void inform(ResourceLoader loader) throws IOException {
|
||||
this.synonymProvider =
|
||||
Word2VecSynonymProviderFactory.getSynonymProvider(loader, word2vecModelFileName, format);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
|
||||
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.TermAndBoost;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
||||
/**
|
||||
* The Word2VecSynonymProvider generates the list of sysnonyms of a term.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class Word2VecSynonymProvider {
|
||||
|
||||
private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
|
||||
VectorSimilarityFunction.DOT_PRODUCT;
|
||||
private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32;
|
||||
private final Word2VecModel word2VecModel;
|
||||
private final HnswGraph hnswGraph;
|
||||
|
||||
/**
|
||||
* Word2VecSynonymProvider constructor
|
||||
*
|
||||
* @param model containing the set of TermAndVector entries
|
||||
*/
|
||||
public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
|
||||
word2VecModel = model;
|
||||
|
||||
HnswGraphBuilder<float[]> builder =
|
||||
HnswGraphBuilder.create(
|
||||
word2VecModel,
|
||||
VECTOR_ENCODING,
|
||||
SIMILARITY_FUNCTION,
|
||||
DEFAULT_MAX_CONN,
|
||||
DEFAULT_BEAM_WIDTH,
|
||||
HnswGraphBuilder.randSeed);
|
||||
this.hnswGraph = builder.build(word2VecModel.copy());
|
||||
}
|
||||
|
||||
public List<TermAndBoost> getSynonyms(
|
||||
BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException {
|
||||
|
||||
if (term == null) {
|
||||
throw new IllegalArgumentException("Term must not be null");
|
||||
}
|
||||
|
||||
LinkedList<TermAndBoost> result = new LinkedList<>();
|
||||
float[] query = word2VecModel.vectorValue(term);
|
||||
if (query != null) {
|
||||
NeighborQueue synonyms =
|
||||
HnswGraphSearcher.search(
|
||||
query,
|
||||
// The query vector is in the model. When looking for the top-k
|
||||
// it's always the nearest neighbour of itself so, we look for the top-k+1
|
||||
maxSynonymsPerTerm + 1,
|
||||
word2VecModel,
|
||||
VECTOR_ENCODING,
|
||||
SIMILARITY_FUNCTION,
|
||||
hnswGraph,
|
||||
null,
|
||||
word2VecModel.size());
|
||||
|
||||
int size = synonyms.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
float similarity = synonyms.topScore();
|
||||
int id = synonyms.pop();
|
||||
|
||||
BytesRef synonym = word2VecModel.termValue(id);
|
||||
// We remove the original query term
|
||||
if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) {
|
||||
result.addFirst(new TermAndBoost(synonym, similarity));
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.apache.lucene.util.ResourceLoader;
|
||||
|
||||
/**
|
||||
* Supply Word2Vec Word2VecSynonymProvider cache avoiding that multiple instances of
|
||||
* Word2VecSynonymFilterFactory will instantiate multiple instances of the same SynonymProvider.
|
||||
* Assumes synonymProvider implementations are thread-safe.
|
||||
*/
|
||||
public class Word2VecSynonymProviderFactory {
|
||||
|
||||
enum Word2VecSupportedFormats {
|
||||
DL4J
|
||||
}
|
||||
|
||||
private static Map<String, Word2VecSynonymProvider> word2vecSynonymProviders =
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
public static Word2VecSynonymProvider getSynonymProvider(
|
||||
ResourceLoader loader, String modelFileName, Word2VecSupportedFormats format)
|
||||
throws IOException {
|
||||
Word2VecSynonymProvider synonymProvider = word2vecSynonymProviders.get(modelFileName);
|
||||
if (synonymProvider == null) {
|
||||
try (InputStream stream = loader.openResource(modelFileName)) {
|
||||
try (Dl4jModelReader reader = getModelReader(format, stream)) {
|
||||
synonymProvider = new Word2VecSynonymProvider(reader.read());
|
||||
}
|
||||
}
|
||||
word2vecSynonymProviders.put(modelFileName, synonymProvider);
|
||||
}
|
||||
return synonymProvider;
|
||||
}
|
||||
|
||||
private static Dl4jModelReader getModelReader(
|
||||
Word2VecSupportedFormats format, InputStream stream) {
|
||||
switch (format) {
|
||||
case DL4J:
|
||||
return new Dl4jModelReader(stream);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/** Analysis components for Synonyms using Word2Vec model. */
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
|
@ -118,6 +118,7 @@ org.apache.lucene.analysis.sv.SwedishLightStemFilterFactory
|
|||
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory
|
||||
org.apache.lucene.analysis.synonym.SynonymFilterFactory
|
||||
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory
|
||||
org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory
|
||||
org.apache.lucene.analysis.core.FlattenGraphFilterFactory
|
||||
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory
|
||||
org.apache.lucene.analysis.te.TeluguStemFilterFactory
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Base64;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestDl4jModelReader extends LuceneTestCase {
|
||||
|
||||
private static final String MODEL_FILE = "word2vec-model.zip";
|
||||
private static final String MODEL_EMPTY_FILE = "word2vec-empty-model.zip";
|
||||
private static final String CORRUPTED_VECTOR_DIMENSION_MODEL_FILE =
|
||||
"word2vec-corrupted-vector-dimension-model.zip";
|
||||
|
||||
InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_FILE);
|
||||
Dl4jModelReader unit = new Dl4jModelReader(stream);
|
||||
|
||||
@Test
|
||||
public void read_zipFileWithMetadata_shouldReturnDictionarySize() throws Exception {
|
||||
Word2VecModel model = unit.read();
|
||||
long expectedDictionarySize = 235;
|
||||
assertEquals(expectedDictionarySize, model.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void read_zipFileWithMetadata_shouldReturnVectorLength() throws Exception {
|
||||
Word2VecModel model = unit.read();
|
||||
int expectedVectorDimension = 100;
|
||||
assertEquals(expectedVectorDimension, model.dimension());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void read_zipFile_shouldReturnDecodedTerm() throws Exception {
|
||||
Word2VecModel model = unit.read();
|
||||
BytesRef expectedDecodedFirstTerm = new BytesRef("it");
|
||||
assertEquals(expectedDecodedFirstTerm, model.termValue(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void decodeTerm_encodedTerm_shouldReturnDecodedTerm() throws Exception {
|
||||
byte[] originalInput = "lucene".getBytes(StandardCharsets.UTF_8);
|
||||
String B64encodedLuceneTerm = Base64.getEncoder().encodeToString(originalInput);
|
||||
String word2vecEncodedLuceneTerm = "B64:" + B64encodedLuceneTerm;
|
||||
assertEquals(new BytesRef("lucene"), Dl4jModelReader.decodeB64Term(word2vecEncodedLuceneTerm));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void read_EmptyZipFile_shouldThrowException() throws Exception {
|
||||
try (InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_EMPTY_FILE)) {
|
||||
Dl4jModelReader unit = new Dl4jModelReader(stream);
|
||||
expectThrows(IllegalArgumentException.class, unit::read);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void read_corruptedVectorDimensionModelFile_shouldThrowException() throws Exception {
|
||||
try (InputStream stream =
|
||||
TestDl4jModelReader.class.getResourceAsStream(CORRUPTED_VECTOR_DIMENSION_MODEL_FILE)) {
|
||||
Dl4jModelReader unit = new Dl4jModelReader(stream);
|
||||
expectThrows(RuntimeException.class, unit::read);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.Tokenizer;
|
||||
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
|
||||
import org.apache.lucene.tests.analysis.MockTokenizer;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.TermAndVector;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestWord2VecSynonymFilter extends BaseTokenStreamTestCase {
|
||||
|
||||
@Test
|
||||
public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() throws Exception {
|
||||
int maxSynonymPerTerm = 10;
|
||||
float minAcceptedSimilarity = 0.9f;
|
||||
Word2VecModel model = new Word2VecModel(6, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
|
||||
|
||||
Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
|
||||
|
||||
Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
|
||||
assertAnalyzesTo(
|
||||
a,
|
||||
"pre a post", // input
|
||||
new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output
|
||||
new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset
|
||||
new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset
|
||||
new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types
|
||||
new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements
|
||||
new int[] {1, 1, 1, 1, 1, 1, 1}); // posLenghts
|
||||
a.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() throws Exception {
|
||||
int maxSynonymPerTerm = 2;
|
||||
float minAcceptedSimilarity = 0.9f;
|
||||
Word2VecModel model = new Word2VecModel(5, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
|
||||
|
||||
Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
|
||||
|
||||
Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
|
||||
assertAnalyzesTo(
|
||||
a,
|
||||
"pre a post", // input
|
||||
new String[] {"pre", "a", "d", "e", "post"}, // output
|
||||
new int[] {0, 4, 4, 4, 6}, // start offset
|
||||
new int[] {3, 5, 5, 5, 10}, // end offset
|
||||
new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types
|
||||
new int[] {1, 1, 0, 0, 1}, // posIncrements
|
||||
new int[] {1, 1, 1, 1, 1}); // posLenghts
|
||||
a.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Exception {
|
||||
Word2VecModel model = new Word2VecModel(8, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {1, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("post"), new float[] {-10, -11}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("after"), new float[] {-8, -10}));
|
||||
|
||||
Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
|
||||
|
||||
Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f);
|
||||
assertAnalyzesTo(
|
||||
a,
|
||||
"pre a post", // input
|
||||
new String[] {"pre", "a", "d", "e", "c", "b", "post", "after"}, // output
|
||||
new int[] {0, 4, 4, 4, 4, 4, 6, 6}, // start offset
|
||||
new int[] {3, 5, 5, 5, 5, 5, 10, 10}, // end offset
|
||||
new String[] { // types
|
||||
"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM"
|
||||
},
|
||||
new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements
|
||||
new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths
|
||||
a.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms()
|
||||
throws Exception {
|
||||
Word2VecModel model = new Word2VecModel(4, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, -10}));
|
||||
|
||||
Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
|
||||
|
||||
Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f);
|
||||
assertAnalyzesTo(
|
||||
a,
|
||||
"pre a post", // input
|
||||
new String[] {"pre", "a", "post"}, // output
|
||||
new int[] {0, 4, 6}, // start offset
|
||||
new int[] {3, 5, 10}, // end offset
|
||||
new String[] {"word", "word", "word"}, // types
|
||||
new int[] {1, 1, 1}, // posIncrements
|
||||
new int[] {1, 1, 1}); // posLengths
|
||||
a.close();
|
||||
}
|
||||
|
||||
private Analyzer getAnalyzer(
|
||||
Word2VecSynonymProvider synonymProvider,
|
||||
int maxSynonymsPerTerm,
|
||||
float minAcceptedSimilarity) {
|
||||
return new Analyzer() {
|
||||
@Override
|
||||
protected TokenStreamComponents createComponents(String fieldName) {
|
||||
Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false);
|
||||
// Make a local variable so testRandomHuge doesn't share it across threads!
|
||||
Word2VecSynonymFilter synFilter =
|
||||
new Word2VecSynonymFilter(
|
||||
tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
|
||||
return new TokenStreamComponents(tokenizer, synFilter);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import org.apache.lucene.tests.analysis.BaseTokenStreamFactoryTestCase;
|
||||
import org.apache.lucene.util.ClasspathResourceLoader;
|
||||
import org.apache.lucene.util.ResourceLoader;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestWord2VecSynonymFilterFactory extends BaseTokenStreamFactoryTestCase {
|
||||
|
||||
public static final String FACTORY_NAME = "Word2VecSynonym";
|
||||
private static final String WORD2VEC_MODEL_FILE = "word2vec-model.zip";
|
||||
|
||||
@Test
|
||||
public void testInform() throws Exception {
|
||||
ResourceLoader loader = new ClasspathResourceLoader(getClass());
|
||||
assertTrue("loader is null and it shouldn't be", loader != null);
|
||||
Word2VecSynonymFilterFactory factory =
|
||||
(Word2VecSynonymFilterFactory)
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "minAcceptedSimilarity", "0.7");
|
||||
|
||||
Word2VecSynonymProvider synonymProvider = factory.getSynonymProvider();
|
||||
assertNotEquals(null, synonymProvider);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void missingRequiredArgument_shouldThrowException() throws Exception {
|
||||
IllegalArgumentException expected =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME,
|
||||
"format",
|
||||
"dl4j",
|
||||
"minAcceptedSimilarity",
|
||||
"0.7",
|
||||
"maxSynonymsPerTerm",
|
||||
"10");
|
||||
});
|
||||
assertTrue(expected.getMessage().contains("Configuration Error: missing parameter 'model'"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void unsupportedModelFormat_shouldThrowException() throws Exception {
|
||||
IllegalArgumentException expected =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "format", "bogusValue");
|
||||
});
|
||||
assertTrue(expected.getMessage().contains("Model format 'BOGUSVALUE' not supported"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bogusArgument_shouldThrowException() throws Exception {
|
||||
IllegalArgumentException expected =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "bogusArg", "bogusValue");
|
||||
});
|
||||
assertTrue(expected.getMessage().contains("Unknown parameters"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void illegalArguments_shouldThrowException() throws Exception {
|
||||
IllegalArgumentException expected =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME,
|
||||
"model",
|
||||
WORD2VEC_MODEL_FILE,
|
||||
"minAcceptedSimilarity",
|
||||
"2",
|
||||
"maxSynonymsPerTerm",
|
||||
"10");
|
||||
});
|
||||
assertTrue(
|
||||
expected
|
||||
.getMessage()
|
||||
.contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 2"));
|
||||
|
||||
expected =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME,
|
||||
"model",
|
||||
WORD2VEC_MODEL_FILE,
|
||||
"minAcceptedSimilarity",
|
||||
"0",
|
||||
"maxSynonymsPerTerm",
|
||||
"10");
|
||||
});
|
||||
assertTrue(
|
||||
expected
|
||||
.getMessage()
|
||||
.contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 0"));
|
||||
|
||||
expected =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME,
|
||||
"model",
|
||||
WORD2VEC_MODEL_FILE,
|
||||
"minAcceptedSimilarity",
|
||||
"0.7",
|
||||
"maxSynonymsPerTerm",
|
||||
"-1");
|
||||
});
|
||||
assertTrue(
|
||||
expected
|
||||
.getMessage()
|
||||
.contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: -1"));
|
||||
|
||||
expected =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> {
|
||||
tokenFilterFactory(
|
||||
FACTORY_NAME,
|
||||
"model",
|
||||
WORD2VEC_MODEL_FILE,
|
||||
"minAcceptedSimilarity",
|
||||
"0.7",
|
||||
"maxSynonymsPerTerm",
|
||||
"0");
|
||||
});
|
||||
assertTrue(
|
||||
expected
|
||||
.getMessage()
|
||||
.contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: 0"));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.TermAndBoost;
|
||||
import org.apache.lucene.util.TermAndVector;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestWord2VecSynonymProvider extends LuceneTestCase {
|
||||
|
||||
private static final int MAX_SYNONYMS_PER_TERM = 10;
|
||||
private static final float MIN_ACCEPTED_SIMILARITY = 0.85f;
|
||||
|
||||
private final Word2VecSynonymProvider unit;
|
||||
|
||||
public TestWord2VecSynonymProvider() throws IOException {
|
||||
Word2VecModel model = new Word2VecModel(2, 3);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {0.24f, 0.78f, 0.28f}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {0.44f, 0.01f, 0.81f}));
|
||||
unit = new Word2VecSynonymProvider(model);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getSynonyms_nullToken_shouldThrowException() {
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> unit.getSynonyms(null, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getSynonyms_shouldReturnSynonymsBasedOnMinAcceptedSimilarity() throws Exception {
|
||||
Word2VecModel model = new Word2VecModel(6, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
|
||||
|
||||
Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
|
||||
|
||||
BytesRef inputTerm = new BytesRef("a");
|
||||
String[] expectedSynonyms = {"d", "e", "c", "b"};
|
||||
List<TermAndBoost> actualSynonymsResults =
|
||||
unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
|
||||
|
||||
assertEquals(4, actualSynonymsResults.size());
|
||||
for (int i = 0; i < expectedSynonyms.length; i++) {
|
||||
assertEquals(new BytesRef(expectedSynonyms[i]), actualSynonymsResults.get(i).term);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getSynonyms_shouldReturnSynonymsBoost() throws Exception {
|
||||
Word2VecModel model = new Word2VecModel(3, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {1, 1}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {99, 101}));
|
||||
|
||||
Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
|
||||
|
||||
BytesRef inputTerm = new BytesRef("a");
|
||||
List<TermAndBoost> actualSynonymsResults =
|
||||
unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
|
||||
|
||||
BytesRef expectedFirstSynonymTerm = new BytesRef("b");
|
||||
double expectedFirstSynonymBoost = 1.0;
|
||||
assertEquals(expectedFirstSynonymTerm, actualSynonymsResults.get(0).term);
|
||||
assertEquals(expectedFirstSynonymBoost, actualSynonymsResults.get(0).boost, 0.001f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void noSynonymsWithinAcceptedSimilarity_shouldReturnNoSynonyms() throws Exception {
|
||||
Word2VecModel model = new Word2VecModel(4, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {6, -6}));
|
||||
|
||||
Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
|
||||
|
||||
BytesRef inputTerm = newBytesRef("a");
|
||||
List<TermAndBoost> actualSynonymsResults =
|
||||
unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
|
||||
assertEquals(0, actualSynonymsResults.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModel_shouldReturnNormalizedVectors() {
|
||||
Word2VecModel model = new Word2VecModel(4, 2);
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
|
||||
model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
|
||||
|
||||
float[] vectorIdA = model.vectorValue(new BytesRef("a"));
|
||||
float[] vectorIdF = model.vectorValue(new BytesRef("f"));
|
||||
assertArrayEquals(new float[] {0.70710f, 0.70710f}, vectorIdA, 0.001f);
|
||||
assertArrayEquals(new float[] {-0.0995f, 0.99503f}, vectorIdF, 0.001f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void normalizedVector_shouldReturnModule1() {
|
||||
TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10});
|
||||
synonymTerm.normalizeVector();
|
||||
float[] vector = synonymTerm.getVector();
|
||||
float len = 0;
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
len += vector[i] * vector[i];
|
||||
}
|
||||
assertEquals(1, Math.sqrt(len), 0.0001f);
|
||||
}
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -62,20 +62,6 @@ public class QueryBuilder {
|
|||
protected boolean enableGraphQueries = true;
|
||||
protected boolean autoGenerateMultiTermSynonymsPhraseQuery = false;
|
||||
|
||||
/** Wraps a term and boost */
|
||||
public static class TermAndBoost {
|
||||
/** the term */
|
||||
public final BytesRef term;
|
||||
/** the boost */
|
||||
public final float boost;
|
||||
|
||||
/** Creates a new TermAndBoost */
|
||||
public TermAndBoost(BytesRef term, float boost) {
|
||||
this.term = BytesRef.deepCopyOf(term);
|
||||
this.boost = boost;
|
||||
}
|
||||
}
|
||||
|
||||
/** Creates a new QueryBuilder using the given analyzer. */
|
||||
public QueryBuilder(Analyzer analyzer) {
|
||||
this.analyzer = analyzer;
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
/** Wraps a term and boost */
|
||||
public class TermAndBoost {
|
||||
/** the term */
|
||||
public final BytesRef term;
|
||||
/** the boost */
|
||||
public final float boost;
|
||||
|
||||
/** Creates a new TermAndBoost */
|
||||
public TermAndBoost(BytesRef term, float boost) {
|
||||
this.term = BytesRef.deepCopyOf(term);
|
||||
this.boost = boost;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* Word2Vec unit composed by a term with the associated vector
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class TermAndVector {
|
||||
|
||||
private final BytesRef term;
|
||||
private final float[] vector;
|
||||
|
||||
public TermAndVector(BytesRef term, float[] vector) {
|
||||
this.term = term;
|
||||
this.vector = vector;
|
||||
}
|
||||
|
||||
public BytesRef getTerm() {
|
||||
return this.term;
|
||||
}
|
||||
|
||||
public float[] getVector() {
|
||||
return this.vector;
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return vector.length;
|
||||
}
|
||||
|
||||
public void normalizeVector() {
|
||||
float vectorLength = 0;
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
vectorLength += vector[i] * vector[i];
|
||||
}
|
||||
vectorLength = (float) Math.sqrt(vectorLength);
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
vector[i] /= vectorLength;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
StringBuilder builder = new StringBuilder(this.term.utf8ToString());
|
||||
builder.append(" [");
|
||||
if (vector.length > 0) {
|
||||
for (int i = 0; i < vector.length - 1; i++) {
|
||||
builder.append(String.format(Locale.ROOT, "%.3f,", vector[i]));
|
||||
}
|
||||
builder.append(String.format(Locale.ROOT, "%.3f]", vector[vector.length - 1]));
|
||||
}
|
||||
return builder.toString();
|
||||
}
|
||||
}
|
|
@ -41,8 +41,17 @@ import org.apache.lucene.util.InfoStream;
|
|||
*/
|
||||
public final class HnswGraphBuilder<T> {
|
||||
|
||||
/** Default number of maximum connections per node */
|
||||
public static final int DEFAULT_MAX_CONN = 16;
|
||||
|
||||
/**
|
||||
* Default size of the queue maintained while searching during graph construction.
|
||||
*/
|
||||
public static final int DEFAULT_BEAM_WIDTH = 100;
|
||||
|
||||
/** Default random seed for level generation */
|
||||
private static final long DEFAULT_RAND_SEED = 42;
|
||||
|
||||
/** A name for the HNSW component for the info-stream */
|
||||
public static final String HNSW_COMPONENT = "HNSW";
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ import org.apache.lucene.util.IOUtils;
|
|||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -62,7 +63,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
|
||||
private static final String KNN_GRAPH_FIELD = "vector";
|
||||
|
||||
private static int M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||
private static int M = HnswGraphBuilder.DEFAULT_MAX_CONN;
|
||||
|
||||
private Codec codec;
|
||||
private Codec float32Codec;
|
||||
|
@ -80,7 +81,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
new Lucene95Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -92,7 +93,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
new Lucene95Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -103,7 +104,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
new Lucene95Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -115,7 +116,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
|
||||
@After
|
||||
public void cleanup() {
|
||||
M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||
M = HnswGraphBuilder.DEFAULT_MAX_CONN;
|
||||
}
|
||||
|
||||
/** Basic test of creating documents in a graph */
|
||||
|
|
|
@ -55,6 +55,7 @@ import org.apache.lucene.document.FieldType;
|
|||
import org.apache.lucene.document.TextField;
|
||||
import org.apache.lucene.index.IndexOptions;
|
||||
import org.apache.lucene.index.IndexableFieldType;
|
||||
import org.apache.lucene.search.BoostAttribute;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
@ -154,7 +155,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
boolean[] keywordAtts,
|
||||
boolean graphOffsetsAreCorrect,
|
||||
byte[][] payloads,
|
||||
int[] flags)
|
||||
int[] flags,
|
||||
float[] boost)
|
||||
throws IOException {
|
||||
assertNotNull(output);
|
||||
CheckClearAttributesAttribute checkClearAtt =
|
||||
|
@ -221,6 +223,12 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
flagsAtt = ts.getAttribute(FlagsAttribute.class);
|
||||
}
|
||||
|
||||
BoostAttribute boostAtt = null;
|
||||
if (boost != null) {
|
||||
assertTrue("has no BoostAttribute", ts.hasAttribute(BoostAttribute.class));
|
||||
boostAtt = ts.getAttribute(BoostAttribute.class);
|
||||
}
|
||||
|
||||
// Maps position to the start/end offset:
|
||||
final Map<Integer, Integer> posToStartOffset = new HashMap<>();
|
||||
final Map<Integer, Integer> posToEndOffset = new HashMap<>();
|
||||
|
@ -243,6 +251,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
if (payloadAtt != null)
|
||||
payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24}));
|
||||
if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's
|
||||
if (boostAtt != null) boostAtt.setBoost(-1f);
|
||||
|
||||
checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
|
||||
assertTrue("token " + i + " does not exist", ts.incrementToken());
|
||||
|
@ -278,6 +287,9 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
if (flagsAtt != null) {
|
||||
assertEquals("flagsAtt " + i + " term=" + termAtt, flags[i], flagsAtt.getFlags());
|
||||
}
|
||||
if (boostAtt != null) {
|
||||
assertEquals("boostAtt " + i + " term=" + termAtt, boost[i], boostAtt.getBoost(), 0.001);
|
||||
}
|
||||
if (payloads != null) {
|
||||
if (payloads[i] != null) {
|
||||
assertEquals("payloads " + i, new BytesRef(payloads[i]), payloadAtt.getPayload());
|
||||
|
@ -405,6 +417,7 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
if (payloadAtt != null)
|
||||
payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24}));
|
||||
if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's
|
||||
if (boostAtt != null) boostAtt.setBoost(-1);
|
||||
|
||||
checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before
|
||||
|
||||
|
@ -426,6 +439,38 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
ts.close();
|
||||
}
|
||||
|
||||
/**
 * Convenience overload of {@code assertTokenStreamContents} for callers that supply payload and
 * flag expectations but nothing beyond them.
 *
 * <p>Every argument is forwarded unchanged to the overload that accepts one additional trailing
 * parameter, and {@code null} is passed for that extra parameter. NOTE(review): the extra
 * parameter is presumably the {@code float[] boost} expectations, with {@code null} meaning that
 * boosts are not checked -- confirm against the target overload.
 *
 * @throws IOException if the delegated assertion throws it
 */
public static void assertTokenStreamContents(
|
||||
TokenStream ts,
|
||||
String[] output,
|
||||
int[] startOffsets,
|
||||
int[] endOffsets,
|
||||
String[] types,
|
||||
int[] posIncrements,
|
||||
int[] posLengths,
|
||||
Integer finalOffset,
|
||||
Integer finalPosInc,
|
||||
boolean[] keywordAtts,
|
||||
boolean graphOffsetsAreCorrect,
|
||||
byte[][] payloads,
|
||||
int[] flags)
|
||||
throws IOException {
|
||||
assertTokenStreamContents(
|
||||
ts,
|
||||
output,
|
||||
startOffsets,
|
||||
endOffsets,
|
||||
types,
|
||||
posIncrements,
|
||||
posLengths,
|
||||
finalOffset,
|
||||
finalPosInc,
|
||||
keywordAtts,
|
||||
graphOffsetsAreCorrect,
|
||||
payloads,
|
||||
flags,
|
||||
null);
|
||||
}
|
||||
|
||||
public static void assertTokenStreamContents(
|
||||
TokenStream ts,
|
||||
String[] output,
|
||||
|
@ -438,6 +483,33 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
boolean[] keywordAtts,
|
||||
boolean graphOffsetsAreCorrect)
|
||||
throws IOException {
|
||||
assertTokenStreamContents(
|
||||
ts,
|
||||
output,
|
||||
startOffsets,
|
||||
endOffsets,
|
||||
types,
|
||||
posIncrements,
|
||||
posLengths,
|
||||
finalOffset,
|
||||
keywordAtts,
|
||||
graphOffsetsAreCorrect,
|
||||
null);
|
||||
}
|
||||
|
||||
public static void assertTokenStreamContents(
|
||||
TokenStream ts,
|
||||
String[] output,
|
||||
int[] startOffsets,
|
||||
int[] endOffsets,
|
||||
String[] types,
|
||||
int[] posIncrements,
|
||||
int[] posLengths,
|
||||
Integer finalOffset,
|
||||
boolean[] keywordAtts,
|
||||
boolean graphOffsetsAreCorrect,
|
||||
float[] boost)
|
||||
throws IOException {
|
||||
assertTokenStreamContents(
|
||||
ts,
|
||||
output,
|
||||
|
@ -451,7 +523,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
keywordAtts,
|
||||
graphOffsetsAreCorrect,
|
||||
null,
|
||||
null);
|
||||
null,
|
||||
boost);
|
||||
}
|
||||
|
||||
public static void assertTokenStreamContents(
|
||||
|
@ -481,9 +554,36 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
keywordAtts,
|
||||
graphOffsetsAreCorrect,
|
||||
payloads,
|
||||
null,
|
||||
null);
|
||||
}
|
||||
|
||||
/**
 * Overload of {@code assertTokenStreamContents} that accepts per-token {@code boost}
 * expectations but no keyword-attribute expectations.
 *
 * <p>Delegates to an eleven-argument overload, inserting {@code null} between {@code
 * finalOffset} and {@code graphOffsetsAreCorrect}; all other arguments are forwarded unchanged.
 * NOTE(review): the {@code null} lines up with the {@code boolean[] keywordAtts} parameter of
 * the (..., finalOffset, keywordAtts, graphOffsetsAreCorrect, boost) overload -- confirm that
 * this is the overload the call resolves to.
 *
 * @throws IOException if the delegated assertion throws it
 */
public static void assertTokenStreamContents(
|
||||
TokenStream ts,
|
||||
String[] output,
|
||||
int[] startOffsets,
|
||||
int[] endOffsets,
|
||||
String[] types,
|
||||
int[] posIncrements,
|
||||
int[] posLengths,
|
||||
Integer finalOffset,
|
||||
boolean graphOffsetsAreCorrect,
|
||||
float[] boost)
|
||||
throws IOException {
|
||||
assertTokenStreamContents(
|
||||
ts,
|
||||
output,
|
||||
startOffsets,
|
||||
endOffsets,
|
||||
types,
|
||||
posIncrements,
|
||||
posLengths,
|
||||
finalOffset,
|
||||
null,
|
||||
graphOffsetsAreCorrect,
|
||||
boost);
|
||||
}
|
||||
|
||||
public static void assertTokenStreamContents(
|
||||
TokenStream ts,
|
||||
String[] output,
|
||||
|
@ -505,7 +605,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
posLengths,
|
||||
finalOffset,
|
||||
null,
|
||||
graphOffsetsAreCorrect);
|
||||
graphOffsetsAreCorrect,
|
||||
null);
|
||||
}
|
||||
|
||||
public static void assertTokenStreamContents(
|
||||
|
@ -522,6 +623,30 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
ts, output, startOffsets, endOffsets, types, posIncrements, posLengths, finalOffset, true);
|
||||
}
|
||||
|
||||
/**
 * Overload of {@code assertTokenStreamContents} that accepts per-token {@code boost}
 * expectations without a {@code graphOffsetsAreCorrect} flag.
 *
 * <p>Delegates to a ten-argument overload, passing the literal {@code true} between {@code
 * finalOffset} and {@code boost}; all other arguments are forwarded unchanged. NOTE(review):
 * the {@code true} lines up with the {@code boolean graphOffsetsAreCorrect} parameter of the
 * (..., finalOffset, graphOffsetsAreCorrect, boost) overload -- confirm that this is the
 * overload the call resolves to.
 *
 * @throws IOException if the delegated assertion throws it
 */
public static void assertTokenStreamContents(
|
||||
TokenStream ts,
|
||||
String[] output,
|
||||
int[] startOffsets,
|
||||
int[] endOffsets,
|
||||
String[] types,
|
||||
int[] posIncrements,
|
||||
int[] posLengths,
|
||||
Integer finalOffset,
|
||||
float[] boost)
|
||||
throws IOException {
|
||||
assertTokenStreamContents(
|
||||
ts,
|
||||
output,
|
||||
startOffsets,
|
||||
endOffsets,
|
||||
types,
|
||||
posIncrements,
|
||||
posLengths,
|
||||
finalOffset,
|
||||
true,
|
||||
boost);
|
||||
}
|
||||
|
||||
public static void assertTokenStreamContents(
|
||||
TokenStream ts,
|
||||
String[] output,
|
||||
|
@ -649,6 +774,21 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
int[] posIncrements,
|
||||
int[] posLengths)
|
||||
throws IOException {
|
||||
assertAnalyzesTo(
|
||||
a, input, output, startOffsets, endOffsets, types, posIncrements, posLengths, null);
|
||||
}
|
||||
|
||||
public static void assertAnalyzesTo(
|
||||
Analyzer a,
|
||||
String input,
|
||||
String[] output,
|
||||
int[] startOffsets,
|
||||
int[] endOffsets,
|
||||
String[] types,
|
||||
int[] posIncrements,
|
||||
int[] posLengths,
|
||||
float[] boost)
|
||||
throws IOException {
|
||||
assertTokenStreamContents(
|
||||
a.tokenStream("dummy", input),
|
||||
output,
|
||||
|
@ -657,7 +797,8 @@ public abstract class BaseTokenStreamTestCase extends LuceneTestCase {
|
|||
types,
|
||||
posIncrements,
|
||||
posLengths,
|
||||
input.length());
|
||||
input.length(),
|
||||
boost);
|
||||
checkResetException(a, input);
|
||||
checkAnalysisConsistency(random(), a, true, input);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue