From 92e460389dc9b0af83c445cb029e3a51799a37dc Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 29 Jun 2017 10:01:49 +0200 Subject: [PATCH] LUCENE-7838 - removed dep from sandbox, created a minimal FLT version specific for knn classification --- .../lucene/classification/classification.iml | 1 - lucene/classification/build.xml | 6 +- .../KNearestFuzzyClassifier.java | 29 +- .../utils/NearestFuzzyQuery.java | 333 ++++++++++++++++++ .../KNearestFuzzyClassifierTest.java | 2 +- 5 files changed, 348 insertions(+), 23 deletions(-) create mode 100644 lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java diff --git a/dev-tools/idea/lucene/classification/classification.iml b/dev-tools/idea/lucene/classification/classification.iml index 44af1e47e47..25810edc899 100644 --- a/dev-tools/idea/lucene/classification/classification.iml +++ b/dev-tools/idea/lucene/classification/classification.iml @@ -19,6 +19,5 @@ - diff --git a/lucene/classification/build.xml b/lucene/classification/build.xml index af7d2b18d72..43bcb4b5a1a 100644 --- a/lucene/classification/build.xml +++ b/lucene/classification/build.xml @@ -28,7 +28,6 @@ - @@ -38,18 +37,17 @@ - + - - diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java index 7bbdbab2cd7..cbd241b4bb4 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java @@ -25,11 +25,11 @@ import java.util.List; import java.util.Map; import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.classification.utils.NearestFuzzyQuery; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.Term; -import 
org.apache.lucene.sandbox.queries.FuzzyLikeThisQuery; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -42,7 +42,7 @@ import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.util.BytesRef; /** - * A k-Nearest Neighbor classifier based on {@link FuzzyLikeThisQuery}. + * A k-Nearest Neighbor classifier based on {@link NearestFuzzyQuery}. * * @lucene.experimental */ @@ -51,27 +51,27 @@ public class KNearestFuzzyClassifier implements Classifier { /** * the name of the fields used as the input text */ - protected final String[] textFieldNames; + private final String[] textFieldNames; /** * the name of the field used as the output text */ - protected final String classFieldName; + private final String classFieldName; /** * an {@link IndexSearcher} used to perform queries */ - protected final IndexSearcher indexSearcher; + private final IndexSearcher indexSearcher; /** * the no. 
of docs to compare in order to find the nearest neighbor to the input text */ - protected final int k; + private final int k; /** * a {@link Query} used to filter the documents that should be used from this classifier's underlying {@link LeafReader} */ - protected final Query query; + private final Query query; private final Analyzer analyzer; /** @@ -145,11 +145,11 @@ public class KNearestFuzzyClassifier implements Classifier { private TopDocs knnSearch(String text) throws IOException { BooleanQuery.Builder bq = new BooleanQuery.Builder(); - FuzzyLikeThisQuery fuzzyLikeThisQuery = new FuzzyLikeThisQuery(300, analyzer); + NearestFuzzyQuery nearestFuzzyQuery = new NearestFuzzyQuery(analyzer); for (String fieldName : textFieldNames) { - fuzzyLikeThisQuery.addTerms(text, fieldName, 1f, 2); // TODO: make this parameters configurable + nearestFuzzyQuery.addTerms(text, fieldName); } - bq.add(fuzzyLikeThisQuery, BooleanClause.Occur.MUST); + bq.add(nearestFuzzyQuery, BooleanClause.Occur.MUST); Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*")); bq.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST)); if (query != null) { @@ -165,7 +165,7 @@ public class KNearestFuzzyClassifier implements Classifier { * @return a {@link List} of {@link ClassificationResult}, one for each existing class * @throws IOException if it's not possible to get the stored value of class field */ - protected List> buildListFromTopDocs(TopDocs topDocs) throws IOException { + private List> buildListFromTopDocs(TopDocs topDocs) throws IOException { Map classCounts = new HashMap<>(); Map classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs float maxScore = topDocs.getMaxScore(); @@ -174,12 +174,7 @@ public class KNearestFuzzyClassifier implements Classifier { if (storableField != null) { BytesRef cl = new BytesRef(storableField.stringValue()); //update count - Integer count = classCounts.get(cl); - if (count != null) { - 
classCounts.put(cl, count + 1); - } else { - classCounts.put(cl, 1); - } + classCounts.merge(cl, 1, (a, b) -> a + b); //update boost, the boost is based on the best score Double totalBoost = classBoosts.get(cl); double singleBoost = scoreDoc.score / maxScore; diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java new file mode 100644 index 00000000000..d4a26341560 --- /dev/null +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.classification.utils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Objects; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.TermContext; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostAttribute; +import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.FuzzyTermsEnum; +import org.apache.lucene.search.MaxNonCompetitiveBoostAttribute; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.util.AttributeSource; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.apache.lucene.util.automaton.LevenshteinAutomata; + +/** + * Simplification of FuzzyLikeThisQuery, to be used in the context of KNN classification. 
+ */ +public class NearestFuzzyQuery extends Query { + + private final ArrayList fieldVals = new ArrayList<>(); + private final Analyzer analyzer; + + // fixed parameters + private static final int MAX_VARIANTS_PER_TERM = 50; + private static final float MIN_SIMILARITY = 1f; + private static final int PREFIX_LENGTH = 2; + private static final int MAX_NUM_TERMS = 300; + + /** + * Default constructor + * + * @param analyzer the analyzer used to process the query text + */ + public NearestFuzzyQuery(Analyzer analyzer) { + this.analyzer = analyzer; + } + + static class FieldVals { + final String queryString; + final String fieldName; + final int maxEdits; + final int prefixLength; + + FieldVals(String name, int maxEdits, String queryString) { + this.fieldName = name; + this.maxEdits = maxEdits; + this.queryString = queryString; + this.prefixLength = NearestFuzzyQuery.PREFIX_LENGTH; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + + ((fieldName == null) ? 0 : fieldName.hashCode()); + result = prime * result + maxEdits; + result = prime * result + prefixLength; + result = prime * result + + ((queryString == null) ? 
0 : queryString.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + FieldVals other = (FieldVals) obj; + if (fieldName == null) { + if (other.fieldName != null) + return false; + } else if (!fieldName.equals(other.fieldName)) + return false; + if (maxEdits != other.maxEdits) { + return false; + } + if (prefixLength != other.prefixLength) + return false; + if (queryString == null) { + if (other.queryString != null) + return false; + } else if (!queryString.equals(other.queryString)) + return false; + return true; + } + + + } + + /** + * Adds user input for "fuzzification" + * + * @param queryString The string which will be parsed by the analyzer and for which fuzzy variants will be generated + */ + public void addTerms(String queryString, String fieldName) { + int maxEdits = (int) MIN_SIMILARITY; + if (maxEdits != MIN_SIMILARITY) { + throw new IllegalArgumentException("MIN_SIMILARITY must integer value between 0 and " + LevenshteinAutomata.MAXIMUM_SUPPORTED_DISTANCE + ", inclusive; got " + MIN_SIMILARITY); + } + fieldVals.add(new FieldVals(fieldName, maxEdits, queryString)); + } + + + private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws IOException { + if (f.queryString == null) return; + final Terms terms = MultiFields.getTerms(reader, f.fieldName); + if (terms == null) { + return; + } + try (TokenStream ts = analyzer.tokenStream(f.fieldName, f.queryString)) { + CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); + + int corpusNumDocs = reader.numDocs(); + HashSet processedTerms = new HashSet<>(); + ts.reset(); + while (ts.incrementToken()) { + String term = termAtt.toString(); + if (!processedTerms.contains(term)) { + processedTerms.add(term); + ScoreTermQueue variantsQ = new ScoreTermQueue(MAX_VARIANTS_PER_TERM); //maxNum variants considered for any one term + float 
minScore = 0; + Term startTerm = new Term(f.fieldName, term); + AttributeSource atts = new AttributeSource(); + MaxNonCompetitiveBoostAttribute maxBoostAtt = + atts.addAttribute(MaxNonCompetitiveBoostAttribute.class); + FuzzyTermsEnum fe = new FuzzyTermsEnum(terms, atts, startTerm, f.maxEdits, f.prefixLength, true); + //store the df so all variants use same idf + int df = reader.docFreq(startTerm); + int numVariants = 0; + int totalVariantDocFreqs = 0; + BytesRef possibleMatch; + BoostAttribute boostAtt = + fe.attributes().addAttribute(BoostAttribute.class); + while ((possibleMatch = fe.next()) != null) { + numVariants++; + totalVariantDocFreqs += fe.docFreq(); + float score = boostAtt.getBoost(); + if (variantsQ.size() < MAX_VARIANTS_PER_TERM || score > minScore) { + ScoreTerm st = new ScoreTerm(new Term(startTerm.field(), BytesRef.deepCopyOf(possibleMatch)), score, startTerm); + variantsQ.insertWithOverflow(st); + minScore = variantsQ.top().score; // maintain minScore + } + maxBoostAtt.setMaxNonCompetitiveBoost(variantsQ.size() >= MAX_VARIANTS_PER_TERM ? 
minScore : Float.NEGATIVE_INFINITY); + } + + if (numVariants > 0) { + int avgDf = totalVariantDocFreqs / numVariants; + if (df == 0)//no direct match we can use as df for all variants + { + df = avgDf; //use avg df of all variants + } + + // take the top variants (scored by edit distance) and reset the score + // to include an IDF factor then add to the global queue for ranking + // overall top query terms + int size = variantsQ.size(); + for (int i = 0; i < size; i++) { + ScoreTerm st = variantsQ.pop(); + if (st != null) { + st.score = (st.score * st.score) * idf(df, corpusNumDocs); + q.insertWithOverflow(st); + } + } + } + } + } + ts.end(); + } + } + + private float idf(int docFreq, int docCount) { + return (float)(Math.log((docCount+1)/(double)(docFreq+1)) + 1.0); + } + + private Query newTermQuery(IndexReader reader, Term term) throws IOException { + // we build an artificial TermContext that will give an overall df and ttf + // equal to 1 + TermContext context = new TermContext(reader.getContext()); + for (LeafReaderContext leafContext : reader.leaves()) { + Terms terms = leafContext.reader().terms(term.field()); + if (terms != null) { + TermsEnum termsEnum = terms.iterator(); + if (termsEnum.seekExact(term.bytes())) { + int freq = 1 - context.docFreq(); // we want the total df and ttf to be 1 + context.register(termsEnum.termState(), leafContext.ord, freq, freq); + } + } + } + return new TermQuery(term, context); + } + + @Override + public Query rewrite(IndexReader reader) throws IOException { + ScoreTermQueue q = new ScoreTermQueue(MAX_NUM_TERMS); + //load up the list of possible terms + for (FieldVals f : fieldVals) { + addTerms(reader, f, q); + } + + BooleanQuery.Builder bq = new BooleanQuery.Builder(); + + //create BooleanQueries to hold the variants for each token/field pair and ensure it + // has no coord factor + //Step 1: sort the termqueries by term/field + HashMap> variantQueries = new HashMap<>(); + int size = q.size(); + for (int i = 0; i < size; 
i++) { + ScoreTerm st = q.pop(); + if (st != null) { + ArrayList l = variantQueries.computeIfAbsent(st.fuzziedSourceTerm, k -> new ArrayList<>()); + l.add(st); + } + } + //Step 2: Organize the sorted termqueries into zero-coord scoring boolean queries + for (ArrayList variants : variantQueries.values()) { + if (variants.size() == 1) { + //optimize where only one selected variant + ScoreTerm st = variants.get(0); + Query tq = newTermQuery(reader, st.term); + // set the boost to a mix of IDF and score + bq.add(new BoostQuery(tq, st.score), BooleanClause.Occur.SHOULD); + } else { + BooleanQuery.Builder termVariants = new BooleanQuery.Builder(); + for (ScoreTerm st : variants) { + // found a match + Query tq = newTermQuery(reader, st.term); + // set the boost using the ScoreTerm's score + termVariants.add(new BoostQuery(tq, st.score), BooleanClause.Occur.SHOULD); // add to query + } + bq.add(termVariants.build(), BooleanClause.Occur.SHOULD); // add to query + } + } + //TODO possible alternative step 3 - organize above booleans into a new layer of field-based + // booleans with a minimum-should-match of NumFields-1? 
+ return bq.build(); + } + + //Holds info for a fuzzy term variant - initially score is set to edit distance (for ranking best + // term variants) then is reset with IDF for use in ranking against all other + // terms/fields + private static class ScoreTerm { + public final Term term; + public float score; + final Term fuzziedSourceTerm; + + ScoreTerm(Term term, float score, Term fuzziedSourceTerm) { + this.term = term; + this.score = score; + this.fuzziedSourceTerm = fuzziedSourceTerm; + } + } + + private static class ScoreTermQueue extends PriorityQueue { + ScoreTermQueue(int size) { + super(size); + } + + /* (non-Javadoc) + * @see org.apache.lucene.util.PriorityQueue#lessThan(java.lang.Object, java.lang.Object) + */ + @Override + protected boolean lessThan(ScoreTerm termA, ScoreTerm termB) { + if (termA.score == termB.score) + return termA.term.compareTo(termB.term) > 0; + else + return termA.score < termB.score; + } + + } + + @Override + public String toString(String field) { + return null; + } + + @Override + public int hashCode() { + int prime = 31; + int result = classHash(); + result = prime * result + Objects.hashCode(analyzer); + result = prime * result + Objects.hashCode(fieldVals); + return result; + } + + @Override + public boolean equals(Object other) { + return sameClassAs(other) && + equalsTo(getClass().cast(other)); + } + + private boolean equalsTo(NearestFuzzyQuery other) { + return Objects.equals(analyzer, other.analyzer) && + Objects.equals(fieldVals, other.fieldVals); + } + +} diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java index 1f70eb427a7..5c5122a1961 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java @@ -28,7 +28,7 @@ 
import org.apache.lucene.util.BytesRef; import org.junit.Test; /** - * Testcase for {@link KNearestFuzzyClassifier} + * Tests for {@link KNearestFuzzyClassifier} */ public class KNearestFuzzyClassifierTest extends ClassificationTestBase {