mirror of https://github.com/apache/lucene.git
SOLR-7739 - applied patch from Alessandro Benedetti for integrating Lucene classification into Solr
(cherry picked from commit 5801caa
)
This commit is contained in:
parent
d4fe0c3849
commit
15fdfd3caa
|
@ -27,6 +27,7 @@
|
|||
<orderEntry type="module" module-name="expressions" />
|
||||
<orderEntry type="module" module-name="analysis-common" />
|
||||
<orderEntry type="module" module-name="lucene-core" />
|
||||
<orderEntry type="module" module-name="classification" />
|
||||
<orderEntry type="module" module-name="queryparser" />
|
||||
<orderEntry type="module" module-name="join" />
|
||||
<orderEntry type="module" module-name="sandbox" />
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
<orderEntry type="module" scope="TEST" module-name="solr-core" />
|
||||
<orderEntry type="module" scope="TEST" module-name="solrj" />
|
||||
<orderEntry type="module" scope="TEST" module-name="lucene-core" />
|
||||
<orderEntry type="module" scope="TEST" module-name="classification" />
|
||||
<orderEntry type="module" scope="TEST" module-name="analysis-common" />
|
||||
<orderEntry type="module" scope="TEST" module-name="queryparser" />
|
||||
<orderEntry type="module" scope="TEST" module-name="queries" />
|
||||
|
|
|
@ -108,6 +108,7 @@
|
|||
<pathelement location="${queryparser.jar}"/>
|
||||
<pathelement location="${join.jar}"/>
|
||||
<pathelement location="${sandbox.jar}"/>
|
||||
<pathelement location="${classification.jar}"/>
|
||||
</path>
|
||||
|
||||
<path id="solr.base.classpath">
|
||||
|
@ -169,7 +170,7 @@
|
|||
|
||||
<target name="prep-lucene-jars"
|
||||
depends="jar-lucene-core, jar-backward-codecs, jar-analyzers-phonetic, jar-analyzers-kuromoji, jar-codecs,jar-expressions, jar-suggest, jar-highlighter, jar-memory,
|
||||
jar-misc, jar-spatial-extras, jar-grouping, jar-queries, jar-queryparser, jar-join, jar-sandbox">
|
||||
jar-misc, jar-spatial-extras, jar-grouping, jar-queries, jar-queryparser, jar-join, jar-sandbox, jar-classification">
|
||||
<property name="solr.deps.compiled" value="true"/>
|
||||
</target>
|
||||
|
||||
|
@ -322,6 +323,7 @@
|
|||
<link offline="true" href="${lucene.javadoc.url}highlighter" packagelistloc="${lucenedocs}/highlighter"/>
|
||||
<link offline="true" href="${lucene.javadoc.url}memory" packagelistloc="${lucenedocs}/memory"/>
|
||||
<link offline="true" href="${lucene.javadoc.url}misc" packagelistloc="${lucenedocs}/misc"/>
|
||||
<link offline="true" href="${lucene.javadoc.url}classification" packagelistloc="${lucenedocs}/classification"/>
|
||||
<link offline="true" href="${lucene.javadoc.url}spatial-extras" packagelistloc="${lucenedocs}/spatial-extras"/>
|
||||
<links/>
|
||||
<link href=""/>
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
package org.apache.solr.update.processor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.classification.ClassificationResult;
|
||||
import org.apache.lucene.classification.document.DocumentClassifier;
|
||||
import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
|
||||
import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.solr.common.SolrInputDocument;
|
||||
import org.apache.solr.schema.IndexSchema;
|
||||
import org.apache.solr.schema.SchemaField;
|
||||
import org.apache.solr.update.AddUpdateCommand;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* This Class is a Request Update Processor to classify the document in input and add a field
|
||||
* containing the class to the Document.
|
||||
* It uses the Lucene Document Classification module, see {@link DocumentClassifier}.
|
||||
*/
|
||||
class ClassificationUpdateProcessor
|
||||
extends UpdateRequestProcessor {
|
||||
|
||||
private String classFieldName; // the field to index the assigned class
|
||||
|
||||
private DocumentClassifier<BytesRef> classifier;
|
||||
|
||||
/**
|
||||
* Sole constructor
|
||||
*
|
||||
* @param inputFieldNames fields to be used as classifier's inputs
|
||||
* @param classFieldName field to be used as classifier's output
|
||||
* @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
|
||||
* @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
|
||||
* @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
|
||||
* @param algorithm the name of the classifier to use
|
||||
* @param next next update processor in the chain
|
||||
* @param indexReader index reader
|
||||
* @param schema schema
|
||||
*/
|
||||
public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
|
||||
UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) {
|
||||
super(next);
|
||||
this.classFieldName = classFieldName;
|
||||
Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
|
||||
for (String fieldName : inputFieldNames) {
|
||||
SchemaField fieldFromSolrSchema = schema.getField(fieldName);
|
||||
Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
|
||||
field2analyzer.put(fieldName, indexAnalyzer);
|
||||
}
|
||||
switch (algorithm) {
|
||||
case "knn":
|
||||
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
|
||||
break;
|
||||
case "bayes":
|
||||
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param cmd the update command in input conaining the Document to classify
|
||||
* @throws IOException If there is a low-level I/O error
|
||||
*/
|
||||
@Override
|
||||
public void processAdd(AddUpdateCommand cmd)
|
||||
throws IOException {
|
||||
SolrInputDocument doc = cmd.getSolrInputDocument();
|
||||
Document luceneDocument = cmd.getLuceneDocument();
|
||||
String assignedClass;
|
||||
Object documentClass = doc.getFieldValue(classFieldName);
|
||||
if (documentClass == null) {
|
||||
ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
|
||||
if (classificationResult != null) {
|
||||
assignedClass = classificationResult.getAssignedClass().utf8ToString();
|
||||
doc.addField(classFieldName, assignedClass);
|
||||
}
|
||||
}
|
||||
super.processAdd(cmd);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,223 @@
|
|||
package org.apache.solr.update.processor;
|
||||
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.response.SolrQueryResponse;
|
||||
import org.apache.solr.schema.IndexSchema;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* This class implements an UpdateProcessorFactory for the Classification Update Processor.
|
||||
* It takes in input a series of parameter that will be necessary to instantiate and use the Classifier
|
||||
*/
|
||||
public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessorFactory {
|
||||
|
||||
// Update Processor Config params
|
||||
private static final String INPUT_FIELDS_PARAM = "inputFields";
|
||||
private static final String CLASS_FIELD_PARAM = "classField";
|
||||
private static final String ALGORITHM_PARAM = "algorithm";
|
||||
private static final String KNN_MIN_TF_PARAM = "knn.minTf";
|
||||
private static final String KNN_MIN_DF_PARAM = "knn.minDf";
|
||||
private static final String KNN_K_PARAM = "knn.k";
|
||||
|
||||
//Update Processor Defaults
|
||||
private static final int DEFAULT_MIN_TF = 1;
|
||||
private static final int DEFAULT_MIN_DF = 1;
|
||||
private static final int DEFAULT_K = 10;
|
||||
private static final String DEFAULT_ALGORITHM = "knn";
|
||||
|
||||
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
|
||||
|
||||
private String classFieldName; // the field containing the class for the Document
|
||||
|
||||
private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
|
||||
|
||||
private int minTf; // knn specific - the minimum Term Frequency for considering a term
|
||||
|
||||
private int minDf; // knn specific - the minimum Document Frequency for considering a term
|
||||
|
||||
private int k; // knn specific - thw window of top results to evaluate, when assgning the class
|
||||
|
||||
@Override
|
||||
public void init(final NamedList args) {
|
||||
if (args != null) {
|
||||
SolrParams params = SolrParams.toSolrParams(args);
|
||||
|
||||
String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
|
||||
checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
|
||||
inputFieldNames = fieldNames.split("\\,");
|
||||
|
||||
classFieldName = params.get(CLASS_FIELD_PARAM);
|
||||
checkNotNull(CLASS_FIELD_PARAM, classFieldName);
|
||||
|
||||
algorithm = params.get(ALGORITHM_PARAM);
|
||||
if (algorithm == null)
|
||||
algorithm = DEFAULT_ALGORITHM;
|
||||
|
||||
minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
|
||||
minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
|
||||
k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Returns an Int parsed param or a default if the param is null
|
||||
*
|
||||
* @param params Solr params in input
|
||||
* @param name the param name
|
||||
* @param defaultValue the param default
|
||||
* @return the Int value for the param
|
||||
*/
|
||||
private int getIntParam(SolrParams params, String name, int defaultValue) {
|
||||
String paramString = params.get(name);
|
||||
int paramInt;
|
||||
if (paramString != null && !paramString.isEmpty()) {
|
||||
paramInt = Integer.parseInt(paramString);
|
||||
} else {
|
||||
paramInt = defaultValue;
|
||||
}
|
||||
return paramInt;
|
||||
}
|
||||
|
||||
private void checkNotNull(String paramName, Object param) {
|
||||
if (param == null) {
|
||||
throw new SolrException
|
||||
(SolrException.ErrorCode.SERVER_ERROR,
|
||||
"Classification UpdateProcessor '" + paramName + "' can not be null");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
|
||||
IndexSchema schema = req.getSchema();
|
||||
LeafReader leafReader = req.getSearcher().getLeafReader();
|
||||
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema);
|
||||
}
|
||||
|
||||
/**
|
||||
* get field names used as classifier's inputs
|
||||
*
|
||||
* @return the input field names
|
||||
*/
|
||||
public String[] getInputFieldNames() {
|
||||
return inputFieldNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* set field names used as classifier's inputs
|
||||
*
|
||||
* @param inputFieldNames the input field names
|
||||
*/
|
||||
public void setInputFieldNames(String[] inputFieldNames) {
|
||||
this.inputFieldNames = inputFieldNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* get field names used as classifier's output
|
||||
*
|
||||
* @return the output field name
|
||||
*/
|
||||
public String getClassFieldName() {
|
||||
return classFieldName;
|
||||
}
|
||||
|
||||
/**
|
||||
* set field names used as classifier's output
|
||||
*
|
||||
* @param classFieldName the output field name
|
||||
*/
|
||||
public void setClassFieldName(String classFieldName) {
|
||||
this.classFieldName = classFieldName;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the name of the classifier algorithm used
|
||||
*
|
||||
* @return the classifier algorithm used
|
||||
*/
|
||||
public String getAlgorithm() {
|
||||
return algorithm;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the name of the classifier algorithm used
|
||||
*
|
||||
* @param algorithm the classifier algorithm used
|
||||
*/
|
||||
public void setAlgorithm(String algorithm) {
|
||||
this.algorithm = algorithm;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the min term frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @return the min term frequency
|
||||
*/
|
||||
public int getMinTf() {
|
||||
return minTf;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the min term frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @param minTf the min term frequency
|
||||
*/
|
||||
public void setMinTf(int minTf) {
|
||||
this.minTf = minTf;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the min document frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @return the min document frequency
|
||||
*/
|
||||
public int getMinDf() {
|
||||
return minDf;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the min document frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @param minDf the min document frequency
|
||||
*/
|
||||
public void setMinDf(int minDf) {
|
||||
this.minDf = minDf;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @return the no. of neighbors to analyze
|
||||
*/
|
||||
public int getK() {
|
||||
return k;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @param k the no. of neighbors to analyze
|
||||
*/
|
||||
public void setK(int k) {
|
||||
this.k = k;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
(the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
<!--
|
||||
Test Schema for a simple Classification Test
|
||||
-->
|
||||
<schema name="classification" version="1.1">
|
||||
<fieldType name="string" class="solr.StrField"/>
|
||||
<fieldType name="text_en" class="solr.TextField" positionIncrementGap="100">
|
||||
<analyzer type="index">
|
||||
<tokenizer class="solr.StandardTokenizerFactory"/>
|
||||
<filter class="solr.LowerCaseFilterFactory"/>
|
||||
<filter class="solr.EnglishPossessiveFilterFactory"/>
|
||||
<filter class="solr.PorterStemFilterFactory"/>
|
||||
</analyzer>
|
||||
<analyzer type="query">
|
||||
<tokenizer class="solr.StandardTokenizerFactory"/>
|
||||
<filter class="solr.LowerCaseFilterFactory"/>
|
||||
<filter class="solr.EnglishPossessiveFilterFactory"/>
|
||||
<filter class="solr.PorterStemFilterFactory"/>
|
||||
</analyzer>
|
||||
</fieldType>
|
||||
<field name="id" type="string" indexed="true" stored="true" required="true"/>
|
||||
<field name="title" type="text_en" indexed="true" stored="true"/>
|
||||
<field name="content" type="text_en" indexed="true" stored="true"/>
|
||||
<field name="author" type="string" indexed="true" stored="true"/>
|
||||
<field name="cat" type="string" indexed="true" stored="true"/>
|
||||
<uniqueKey>id</uniqueKey>
|
||||
</schema>
|
|
@ -0,0 +1,53 @@
|
|||
<?xml version="1.0" ?>
|
||||
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
(the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
<!--
|
||||
Test Config for a simple Classification Update Request Processor Chain
|
||||
-->
|
||||
<config>
|
||||
<luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion>
|
||||
<xi:include xmlns:xi="http://www.w3.org/2001/XInclude" href="solrconfig.snippet.randomindexconfig.xml"/>
|
||||
<requestHandler name="standard" class="solr.StandardRequestHandler"></requestHandler>
|
||||
<directoryFactory name="DirectoryFactory" class="${solr.directoryFactory:solr.RAMDirectoryFactory}"/>
|
||||
<schemaFactory class="ClassicIndexSchemaFactory"/>
|
||||
|
||||
<updateHandler class="solr.DirectUpdateHandler2">
|
||||
<updateLog enable="${enable.update.log:true}">
|
||||
<str name="dir">${solr.ulog.dir:}</str>
|
||||
</updateLog>
|
||||
|
||||
<commitWithin>
|
||||
<softCommit>${solr.commitwithin.softcommit:true}</softCommit>
|
||||
</commitWithin>
|
||||
|
||||
</updateHandler>
|
||||
|
||||
<updateRequestProcessorChain name="classification">
|
||||
<processor class="solr.ClassificationUpdateProcessorFactory">
|
||||
<str name="inputFields">title,content,author</str>
|
||||
<str name="classField">cat</str>
|
||||
<!-- Knn algorithm specific-->
|
||||
<str name="algorithm">knn</str>
|
||||
<str name="knn.minTf">1</str>
|
||||
<str name="knn.minDf">1</str>
|
||||
<str name="knn.k">5</str>
|
||||
</processor>
|
||||
<processor class="solr.RunUpdateProcessorFactory"/>
|
||||
</updateRequestProcessorChain>
|
||||
</config>
|
|
@ -0,0 +1,234 @@
|
|||
package org.apache.solr.update.processor;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.solr.SolrTestCaseJ4;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.params.MultiMapSolrParams;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.params.UpdateParams;
|
||||
import org.apache.solr.common.util.ContentStream;
|
||||
import org.apache.solr.common.util.ContentStreamBase;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.handler.UpdateRequestHandler;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.request.SolrQueryRequestBase;
|
||||
import org.apache.solr.response.SolrQueryResponse;
|
||||
import org.apache.solr.search.SolrIndexSearcher;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
|
||||
*/
|
||||
public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
|
||||
// field names are used in accordance with the solrconfig and schema supplied
|
||||
private static final String ID = "id";
|
||||
private static final String TITLE = "title";
|
||||
private static final String CONTENT = "content";
|
||||
private static final String AUTHOR = "author";
|
||||
private static final String CLASS = "cat";
|
||||
|
||||
private static final String CHAIN = "classification";
|
||||
|
||||
|
||||
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
|
||||
private NamedList args = new NamedList<String>();
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
System.setProperty("enable.update.log", "false");
|
||||
initCore("solrconfig-classification.xml", "schema-classification.xml");
|
||||
}
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
clearIndex();
|
||||
assertU(commit());
|
||||
}
|
||||
|
||||
@Before
|
||||
public void initArgs() {
|
||||
args.add("inputFields", "inputField1,inputField2");
|
||||
args.add("classField", "classField1");
|
||||
args.add("algorithm", "bayes");
|
||||
args.add("knn.k", "9");
|
||||
args.add("knn.minDf", "8");
|
||||
args.add("knn.minTf", "10");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFullInit() {
|
||||
cFactoryToTest.init(args);
|
||||
|
||||
String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
|
||||
assertEquals("inputField1", inputFieldNames[0]);
|
||||
assertEquals("inputField2", inputFieldNames[1]);
|
||||
assertEquals("classField1", cFactoryToTest.getClassFieldName());
|
||||
assertEquals("bayes", cFactoryToTest.getAlgorithm());
|
||||
assertEquals(8, cFactoryToTest.getMinDf());
|
||||
assertEquals(10, cFactoryToTest.getMinTf());
|
||||
assertEquals(9, cFactoryToTest.getK());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitEmptyInputField() {
|
||||
args.removeAll("inputFields");
|
||||
try {
|
||||
cFactoryToTest.init(args);
|
||||
} catch (SolrException e) {
|
||||
assertEquals("Classification UpdateProcessor 'inputFields' can not be null", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitEmptyClassField() {
|
||||
args.removeAll("classField");
|
||||
try {
|
||||
cFactoryToTest.init(args);
|
||||
} catch (SolrException e) {
|
||||
assertEquals("Classification UpdateProcessor 'classField' can not be null", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDefaults() {
|
||||
args.removeAll("algorithm");
|
||||
args.removeAll("knn.k");
|
||||
args.removeAll("knn.minDf");
|
||||
args.removeAll("knn.minTf");
|
||||
cFactoryToTest.init(args);
|
||||
assertEquals("knn", cFactoryToTest.getAlgorithm());
|
||||
assertEquals(1, cFactoryToTest.getMinDf());
|
||||
assertEquals(1, cFactoryToTest.getMinTf());
|
||||
assertEquals(10, cFactoryToTest.getK());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicClassification() throws Exception {
|
||||
prepareTrainedIndex();
|
||||
// To be classified,we index documents without a class and verify the expected one is returned
|
||||
addDoc(adoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 ",
|
||||
AUTHOR, "Name1 Surname1"));
|
||||
addDoc(adoc(ID, "11",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname"));
|
||||
addDoc(commit());
|
||||
|
||||
Document doc10 = getDoc("10");
|
||||
assertEquals("class2", doc10.get(CLASS));
|
||||
Document doc11 = getDoc("11");
|
||||
assertEquals("class1", doc11.get(CLASS));
|
||||
}
|
||||
|
||||
/**
|
||||
* Index some example documents with a class manually assigned.
|
||||
* This will be our trained model.
|
||||
*
|
||||
* @throws Exception If there is a low-level I/O error
|
||||
*/
|
||||
private void prepareTrainedIndex() throws Exception {
|
||||
//class1
|
||||
addDoc(adoc(ID, "1",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
addDoc(adoc(ID, "2",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
addDoc(adoc(ID, "3",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
addDoc(adoc(ID, "4",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
//class2
|
||||
addDoc(adoc(ID, "5",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(adoc(ID, "6",
|
||||
TITLE, "word4 word4",
|
||||
CONTENT, "word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(adoc(ID, "7",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(adoc(ID, "8",
|
||||
TITLE, "word4",
|
||||
CONTENT, "word5 word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(commit());
|
||||
}
|
||||
|
||||
private Document getDoc(String id) throws IOException {
|
||||
try (SolrQueryRequest req = req()) {
|
||||
SolrIndexSearcher searcher = req.getSearcher();
|
||||
TermQuery query = new TermQuery(new Term(ID, id));
|
||||
TopDocs doc1 = searcher.search(query, 1);
|
||||
ScoreDoc scoreDoc = doc1.scoreDocs[0];
|
||||
return searcher.doc(scoreDoc.doc);
|
||||
}
|
||||
}
|
||||
|
||||
static void addDoc(String doc) throws Exception {
|
||||
Map<String, String[]> params = new HashMap<>();
|
||||
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
|
||||
params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
|
||||
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
|
||||
(SolrParams) mmparams) {
|
||||
};
|
||||
|
||||
UpdateRequestHandler handler = new UpdateRequestHandler();
|
||||
handler.init(null);
|
||||
ArrayList<ContentStream> streams = new ArrayList<>(2);
|
||||
streams.add(new ContentStreamBase.StringStream(doc));
|
||||
req.setContentStreams(streams);
|
||||
handler.handleRequestBody(req, new SolrQueryResponse());
|
||||
req.close();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue