From 15fdfd3caa54c1e6d8cf3435ded285c3b65bea8c Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 15 Mar 2016 12:29:07 +0100 Subject: [PATCH] SOLR-7739 - applied patch from Alessandro Benedetti for integrating Lucene classification into Solr (cherry picked from commit 5801caa) --- .../idea/solr/core/src/java/solr-core.iml | 1 + .../idea/solr/core/src/solr-core-tests.iml | 1 + solr/common-build.xml | 4 +- .../ClassificationUpdateProcessor.java | 102 ++++++++ .../ClassificationUpdateProcessorFactory.java | 223 +++++++++++++++++ .../conf/schema-classification.xml | 43 ++++ .../conf/solrconfig-classification.xml | 53 ++++ ...ssificationUpdateProcessorFactoryTest.java | 234 ++++++++++++++++++ 8 files changed, 660 insertions(+), 1 deletion(-) create mode 100644 solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java create mode 100644 solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java create mode 100644 solr/core/src/test-files/solr/collection1/conf/schema-classification.xml create mode 100644 solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml create mode 100644 solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java diff --git a/dev-tools/idea/solr/core/src/java/solr-core.iml b/dev-tools/idea/solr/core/src/java/solr-core.iml index f03268c350b..822b24f6cab 100644 --- a/dev-tools/idea/solr/core/src/java/solr-core.iml +++ b/dev-tools/idea/solr/core/src/java/solr-core.iml @@ -27,6 +27,7 @@ + diff --git a/dev-tools/idea/solr/core/src/solr-core-tests.iml b/dev-tools/idea/solr/core/src/solr-core-tests.iml index c9f722a1212..56f768b49f3 100644 --- a/dev-tools/idea/solr/core/src/solr-core-tests.iml +++ b/dev-tools/idea/solr/core/src/solr-core-tests.iml @@ -21,6 +21,7 @@ + diff --git a/solr/common-build.xml b/solr/common-build.xml index 6a069286077..78e10aabac1 100644 --- a/solr/common-build.xml +++ b/solr/common-build.xml @@ -108,6 +108,7 @@ + @@ -169,7 +170,7 @@ + jar-misc, jar-spatial-extras, jar-grouping, jar-queries, jar-queryparser, jar-join, jar-sandbox, jar-classification"> @@ -322,6 +323,7 @@ + diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java new file mode 100644 index 00000000000..b752565cd6d --- /dev/null +++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java @@ -0,0 +1,102 @@ +package org.apache.solr.update.processor; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.classification.ClassificationResult; +import org.apache.lucene.classification.document.DocumentClassifier; +import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier; +import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier; +import org.apache.lucene.document.Document; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.util.BytesRef; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; +import org.apache.solr.update.AddUpdateCommand; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This class is an update request processor which classifies the document in input and adds a field
+ * containing the assigned class to the document.
+ * It uses the Lucene Document Classification module, see {@link DocumentClassifier}.
+ */
+class ClassificationUpdateProcessor
+    extends UpdateRequestProcessor {
+
+  private String classFieldName; // the field to index the assigned class
+
+  private DocumentClassifier<BytesRef> classifier;
+
+  /**
+   * Sole constructor
+   *
+   * @param inputFieldNames fields to be used as classifier's inputs
+   * @param classFieldName field to be used as classifier's output
+   * @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
+   * @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
+   * @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
+   * @param algorithm the name of the classifier to use
+   * @param next next update processor in the chain
+   * @param indexReader index reader
+   * @param schema schema
+   */
+  public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
+                                       UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) {
+    super(next);
+    this.classFieldName = classFieldName;
+    Map<String, Analyzer> field2analyzer = new HashMap<>();
+    for (String fieldName : inputFieldNames) {
+      SchemaField fieldFromSolrSchema = schema.getField(fieldName);
+      Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
+      field2analyzer.put(fieldName, indexAnalyzer);
+    }
+    switch (algorithm) {
+      case "knn":
+        classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
+        break;
+      case "bayes":
+        classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
+        break;
+    }
+  }
+
+  /**
+   * @param cmd the update command in input containing the Document to classify
+   * @throws IOException If there is a low-level I/O error
+   */
+  @Override
+  public void processAdd(AddUpdateCommand cmd)
+      throws IOException {
+    SolrInputDocument doc = cmd.getSolrInputDocument();
+    Document luceneDocument = cmd.getLuceneDocument();
+    String assignedClass;
+    Object documentClass = doc.getFieldValue(classFieldName);
+    if (documentClass == null) {
+      ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
+      if (classificationResult != null) {
+        assignedClass = classificationResult.getAssignedClass().utf8ToString();
+        doc.addField(classFieldName, assignedClass);
+      }
+    }
+    super.processAdd(cmd);
+  }
+}
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
new file mode 100644
index 00000000000..79b3240f005
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
@@ -0,0 +1,223 @@
+package org.apache.solr.update.processor;
+
+import org.apache.lucene.index.LeafReader;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.response.SolrQueryResponse;
+import org.apache.solr.schema.IndexSchema;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This class implements an UpdateProcessorFactory for the Classification Update Processor.
+ * It takes in input a series of parameters that are necessary to instantiate and use the Classifier.
+ */
+public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessorFactory {
+
+  // Update Processor Config params
+  private static final String INPUT_FIELDS_PARAM = "inputFields";
+  private static final String CLASS_FIELD_PARAM = "classField";
+  private static final String ALGORITHM_PARAM = "algorithm";
+  private static final String KNN_MIN_TF_PARAM = "knn.minTf";
+  private static final String KNN_MIN_DF_PARAM = "knn.minDf";
+  private static final String KNN_K_PARAM = "knn.k";
+
+  // Update Processor Defaults
+  private static final int DEFAULT_MIN_TF = 1;
+  private static final int DEFAULT_MIN_DF = 1;
+  private static final int DEFAULT_K = 10;
+  private static final String DEFAULT_ALGORITHM = "knn";
+
+  private String[] inputFieldNames; // the array of fields to be sent to the Classifier
+
+  private String classFieldName; // the field containing the class for the Document
+
+  private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
+
+  private int minTf; // knn specific - the minimum Term Frequency for considering a term
+
+  private int minDf; // knn specific - the minimum Document Frequency for considering a term
+
+  private int k; // knn specific - the window of top results to evaluate when assigning the class
+
+  @Override
+  public void init(final NamedList args) {
+    if (args != null) {
+      SolrParams params = SolrParams.toSolrParams(args);
+
+      String fieldNames = params.get(INPUT_FIELDS_PARAM); // must be a comma separated list of fields
+      checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
+      inputFieldNames = fieldNames.split("\\,");
+
+      classFieldName = params.get(CLASS_FIELD_PARAM);
+      checkNotNull(CLASS_FIELD_PARAM, classFieldName);
+
+      algorithm = params.get(ALGORITHM_PARAM);
+      if (algorithm == null)
+        algorithm = DEFAULT_ALGORITHM;
+
+      minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
+      minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
+      k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
+    }
+  }
+
+  /*
+   * Returns an Int parsed param or a default if the param is null
+   *
+   * @param params Solr params in input
+   * @param name the param name
+   * @param defaultValue the param default
+   * @return the Int value for the param
+   */
+  private int getIntParam(SolrParams params, String name, int defaultValue) {
+    String paramString = params.get(name);
+    int paramInt;
+    if (paramString != null && !paramString.isEmpty()) {
+      paramInt = Integer.parseInt(paramString);
+    } else {
+      paramInt = defaultValue;
+    }
+    return paramInt;
+  }
+
+  private void checkNotNull(String paramName, Object param) {
+    if (param == null) {
+      throw new SolrException
+          (SolrException.ErrorCode.SERVER_ERROR,
+              "Classification UpdateProcessor '" + paramName + "' can not be null");
+    }
+  }
+
+  @Override
+  public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
+    IndexSchema schema = req.getSchema();
+    LeafReader leafReader = req.getSearcher().getLeafReader();
+    return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema);
+  }
+
+  /**
+   * get the field names used as the classifier's inputs
+   *
+   * @return the input field names
+   */
+  public String[] getInputFieldNames() {
+    return inputFieldNames;
+  }
+
+  /**
+   * set the field names used as the classifier's inputs
+   *
+   * @param inputFieldNames the input field names
+   */
+  public void setInputFieldNames(String[] inputFieldNames) {
+    this.inputFieldNames = inputFieldNames;
+  }
+
+  /**
+   * get the field name used as the classifier's output
+   *
+   * @return the output field name
+   */
+  public String getClassFieldName() {
+    return classFieldName;
+  }
+
+  /**
+   * set the field name used as the classifier's output
+   *
+   * @param classFieldName the output field name
+   */
+  public void setClassFieldName(String classFieldName) {
+    this.classFieldName = classFieldName;
+  }
+
+  /**
+   * get the name of the classifier algorithm used
+   *
+   * @return the classifier algorithm used
+   */
+  public String getAlgorithm() {
+    return algorithm;
+  }
+
+  /**
+   * set the name of the classifier algorithm used
+   *
+   * @param algorithm the classifier algorithm used
+   */
+  public void setAlgorithm(String algorithm) {
+    this.algorithm = algorithm;
+  }
+
+  /**
+   * get the min term frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @return the min term frequency
+   */
+  public int getMinTf() {
+    return minTf;
+  }
+
+  /**
+   * set the min term frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @param minTf the min term frequency
+   */
+  public void setMinTf(int minTf) {
+    this.minTf = minTf;
+  }
+
+  /**
+   * get the min document frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @return the min document frequency
+   */
+  public int getMinDf() {
+    return minDf;
+  }
+
+  /**
+   * set the min document frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @param minDf the min document frequency
+   */
+  public void setMinDf(int minDf) {
+    this.minDf = minDf;
+  }
+
+  /**
+   * get the number of nearest neighbors to analyze, to be used in case algorithm is {@code "knn"}
+   *
+   * @return the number of neighbors to analyze
+   */
+  public int getK() {
+    return k;
+  }
+
+  /**
+   * set the number of nearest neighbors to analyze, to be used in case algorithm is {@code "knn"}
+   *
+   * @param k the number
of neighbors to analyze + */ + public void setK(int k) { + this.k = k; + } +} diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-classification.xml b/solr/core/src/test-files/solr/collection1/conf/schema-classification.xml new file mode 100644 index 00000000000..89c27a6766a --- /dev/null +++ b/solr/core/src/test-files/solr/collection1/conf/schema-classification.xml @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + + + + + + + + id + diff --git a/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml new file mode 100644 index 00000000000..3656335d184 --- /dev/null +++ b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml @@ -0,0 +1,53 @@ + + + + + + + ${tests.luceneMatchVersion:LATEST} + + + + + + + + ${solr.ulog.dir:} + + + + ${solr.commitwithin.softcommit:true} + + + + + + + title,content,author + cat + + knn + 1 + 1 + 5 + + + + diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java new file mode 100644 index 00000000000..27d8dca71ee --- /dev/null +++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java @@ -0,0 +1,234 @@ +package org.apache.solr.update.processor; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.document.Document; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.MultiMapSolrParams; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.params.UpdateParams; +import org.apache.solr.common.util.ContentStream; +import org.apache.solr.common.util.ContentStreamBase; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.handler.UpdateRequestHandler; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.request.SolrQueryRequestBase; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.search.SolrIndexSearcher; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory} + */ +public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 { + // field names are used in accordance with the solrconfig and schema supplied + private static final String ID = "id"; + private static final String TITLE = "title"; + private static final String CONTENT = "content"; + private static final String AUTHOR = "author"; + private static final String CLASS = "cat"; + + private static final String CHAIN = "classification"; + + + private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory(); + private NamedList args = new NamedList(); + + @BeforeClass + public static void beforeClass() throws Exception { + System.setProperty("enable.update.log", "false"); + initCore("solrconfig-classification.xml", "schema-classification.xml"); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clearIndex(); + assertU(commit()); + } + + @Before + public void initArgs() { + args.add("inputFields", "inputField1,inputField2"); + args.add("classField", "classField1"); + args.add("algorithm", "bayes"); + args.add("knn.k", "9"); + args.add("knn.minDf", "8"); + args.add("knn.minTf", "10"); + } + + @Test + public void testFullInit() { + cFactoryToTest.init(args); + + String[] inputFieldNames = cFactoryToTest.getInputFieldNames(); + assertEquals("inputField1", inputFieldNames[0]); + assertEquals("inputField2", inputFieldNames[1]); + assertEquals("classField1", cFactoryToTest.getClassFieldName()); + assertEquals("bayes", cFactoryToTest.getAlgorithm()); + assertEquals(8, cFactoryToTest.getMinDf()); + assertEquals(10, cFactoryToTest.getMinTf()); + assertEquals(9, cFactoryToTest.getK()); + + } + + @Test + public void testInitEmptyInputField() { + args.removeAll("inputFields"); + try { + cFactoryToTest.init(args); + } catch (SolrException e) { + assertEquals("Classification UpdateProcessor 'inputFields' can not be null", e.getMessage()); + } + } + + @Test + public void testInitEmptyClassField() { + args.removeAll("classField"); + try { + cFactoryToTest.init(args); + } catch (SolrException e) { + assertEquals("Classification UpdateProcessor 'classField' can not be null", e.getMessage()); + } + } + + @Test + public void testDefaults() { + args.removeAll("algorithm"); + args.removeAll("knn.k"); + args.removeAll("knn.minDf"); + args.removeAll("knn.minTf"); + 
+    cFactoryToTest.init(args);
+    assertEquals("knn", cFactoryToTest.getAlgorithm());
+    assertEquals(1, cFactoryToTest.getMinDf());
+    assertEquals(1, cFactoryToTest.getMinTf());
+    assertEquals(10, cFactoryToTest.getK());
+  }
+
+  @Test
+  public void testBasicClassification() throws Exception {
+    prepareTrainedIndex();
+    // To be classified, we index documents without a class and verify the expected one is returned
+    addDoc(adoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 ",
+        AUTHOR, "Name1 Surname1"));
+    addDoc(adoc(ID, "11",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname"));
+    addDoc(commit());
+
+    Document doc10 = getDoc("10");
+    assertEquals("class2", doc10.get(CLASS));
+    Document doc11 = getDoc("11");
+    assertEquals("class1", doc11.get(CLASS));
+  }
+
+  /**
+   * Index some example documents with a class manually assigned.
+   * This will be our trained model.
+   *
+   * @throws Exception If there is a low-level I/O error
+   */
+  private void prepareTrainedIndex() throws Exception {
+    // class1
+    addDoc(adoc(ID, "1",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    addDoc(adoc(ID, "2",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    addDoc(adoc(ID, "3",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    addDoc(adoc(ID, "4",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    // class2
+    addDoc(adoc(ID, "5",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(adoc(ID, "6",
+        TITLE, "word4 word4",
+        CONTENT, "word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(adoc(ID, "7",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(adoc(ID, "8",
+        TITLE, "word4",
+        CONTENT, "word5 word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(commit());
+  }
+
+  private Document getDoc(String id) throws IOException {
+    try (SolrQueryRequest req = req()) {
+      SolrIndexSearcher searcher = req.getSearcher();
+      TermQuery query = new TermQuery(new Term(ID, id));
+      TopDocs doc1 = searcher.search(query, 1);
+      ScoreDoc scoreDoc = doc1.scoreDocs[0];
+      return searcher.doc(scoreDoc.doc);
+    }
+  }
+
+  static void addDoc(String doc) throws Exception {
+    Map<String, String[]> params = new HashMap<>();
+    MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
+    params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
+    SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
+        (SolrParams) mmparams) {
+    };
+
+    UpdateRequestHandler handler = new UpdateRequestHandler();
+    handler.init(null);
+    ArrayList<ContentStream> streams = new ArrayList<>(2);
+    streams.add(new ContentStreamBase.StringStream(doc));
+    req.setContentStreams(streams);
+    handler.handleRequestBody(req, new SolrQueryResponse());
+    req.close();
+  }
+}
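
Note on the test configuration: in this copy of the patch the XML markup of schema-classification.xml and solrconfig-classification.xml was stripped during extraction, leaving only a few literal values ("title,content,author", "cat", "knn", "1", "1", "5"). The snippet below is only a sketch of what the classification update chain in the test solrconfig presumably declares, reconstructed from the parameter names read by ClassificationUpdateProcessorFactory.init() and those surviving values; it is not the verbatim file from the patch.

  <!-- sketch reconstructed from the factory's parameter names; not the verbatim test config -->
  <updateRequestProcessorChain name="classification">
    <processor class="solr.ClassificationUpdateProcessorFactory">
      <str name="inputFields">title,content,author</str>
      <str name="classField">cat</str>
      <str name="algorithm">knn</str>
      <str name="knn.minDf">1</str>
      <str name="knn.minTf">1</str>
      <str name="knn.k">5</str>
    </processor>
    <processor class="solr.RunUpdateProcessorFactory"/>
  </updateRequestProcessorChain>

The test's addDoc() helper selects this chain by setting UpdateParams.UPDATE_CHAIN to "classification", so documents indexed without a "cat" value have one assigned by the classifier before being written to the index.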