mirror of https://github.com/apache/lucene.git
SOLR-7739 - applied patch from Alessandro Benedetti for integrating Lucene classification into Solr
This commit is contained in:
parent
56ca641b5b
commit
5801caab6c
|
@ -27,6 +27,7 @@
|
||||||
<orderEntry type="module" module-name="expressions" />
|
<orderEntry type="module" module-name="expressions" />
|
||||||
<orderEntry type="module" module-name="analysis-common" />
|
<orderEntry type="module" module-name="analysis-common" />
|
||||||
<orderEntry type="module" module-name="lucene-core" />
|
<orderEntry type="module" module-name="lucene-core" />
|
||||||
|
<orderEntry type="module" module-name="classification" />
|
||||||
<orderEntry type="module" module-name="queryparser" />
|
<orderEntry type="module" module-name="queryparser" />
|
||||||
<orderEntry type="module" module-name="join" />
|
<orderEntry type="module" module-name="join" />
|
||||||
<orderEntry type="module" module-name="sandbox" />
|
<orderEntry type="module" module-name="sandbox" />
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
<orderEntry type="module" scope="TEST" module-name="solr-core" />
|
<orderEntry type="module" scope="TEST" module-name="solr-core" />
|
||||||
<orderEntry type="module" scope="TEST" module-name="solrj" />
|
<orderEntry type="module" scope="TEST" module-name="solrj" />
|
||||||
<orderEntry type="module" scope="TEST" module-name="lucene-core" />
|
<orderEntry type="module" scope="TEST" module-name="lucene-core" />
|
||||||
|
<orderEntry type="module" scope="TEST" module-name="classification" />
|
||||||
<orderEntry type="module" scope="TEST" module-name="analysis-common" />
|
<orderEntry type="module" scope="TEST" module-name="analysis-common" />
|
||||||
<orderEntry type="module" scope="TEST" module-name="queryparser" />
|
<orderEntry type="module" scope="TEST" module-name="queryparser" />
|
||||||
<orderEntry type="module" scope="TEST" module-name="queries" />
|
<orderEntry type="module" scope="TEST" module-name="queries" />
|
||||||
|
|
|
@ -108,6 +108,7 @@
|
||||||
<pathelement location="${queryparser.jar}"/>
|
<pathelement location="${queryparser.jar}"/>
|
||||||
<pathelement location="${join.jar}"/>
|
<pathelement location="${join.jar}"/>
|
||||||
<pathelement location="${sandbox.jar}"/>
|
<pathelement location="${sandbox.jar}"/>
|
||||||
|
<pathelement location="${classification.jar}"/>
|
||||||
</path>
|
</path>
|
||||||
|
|
||||||
<path id="solr.base.classpath">
|
<path id="solr.base.classpath">
|
||||||
|
@ -169,7 +170,7 @@
|
||||||
|
|
||||||
<target name="prep-lucene-jars"
|
<target name="prep-lucene-jars"
|
||||||
depends="jar-lucene-core, jar-backward-codecs, jar-analyzers-phonetic, jar-analyzers-kuromoji, jar-codecs,jar-expressions, jar-suggest, jar-highlighter, jar-memory,
|
depends="jar-lucene-core, jar-backward-codecs, jar-analyzers-phonetic, jar-analyzers-kuromoji, jar-codecs,jar-expressions, jar-suggest, jar-highlighter, jar-memory,
|
||||||
jar-misc, jar-spatial-extras, jar-grouping, jar-queries, jar-queryparser, jar-join, jar-sandbox">
|
jar-misc, jar-spatial-extras, jar-grouping, jar-queries, jar-queryparser, jar-join, jar-sandbox, jar-classification">
|
||||||
<property name="solr.deps.compiled" value="true"/>
|
<property name="solr.deps.compiled" value="true"/>
|
||||||
</target>
|
</target>
|
||||||
|
|
||||||
|
@ -322,6 +323,7 @@
|
||||||
<link offline="true" href="${lucene.javadoc.url}highlighter" packagelistloc="${lucenedocs}/highlighter"/>
|
<link offline="true" href="${lucene.javadoc.url}highlighter" packagelistloc="${lucenedocs}/highlighter"/>
|
||||||
<link offline="true" href="${lucene.javadoc.url}memory" packagelistloc="${lucenedocs}/memory"/>
|
<link offline="true" href="${lucene.javadoc.url}memory" packagelistloc="${lucenedocs}/memory"/>
|
||||||
<link offline="true" href="${lucene.javadoc.url}misc" packagelistloc="${lucenedocs}/misc"/>
|
<link offline="true" href="${lucene.javadoc.url}misc" packagelistloc="${lucenedocs}/misc"/>
|
||||||
|
<link offline="true" href="${lucene.javadoc.url}classification" packagelistloc="${lucenedocs}/classification"/>
|
||||||
<link offline="true" href="${lucene.javadoc.url}spatial-extras" packagelistloc="${lucenedocs}/spatial-extras"/>
|
<link offline="true" href="${lucene.javadoc.url}spatial-extras" packagelistloc="${lucenedocs}/spatial-extras"/>
|
||||||
<links/>
|
<links/>
|
||||||
<link href=""/>
|
<link href=""/>
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
package org.apache.solr.update.processor;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
|
import org.apache.lucene.classification.ClassificationResult;
|
||||||
|
import org.apache.lucene.classification.document.DocumentClassifier;
|
||||||
|
import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
|
||||||
|
import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
|
||||||
|
import org.apache.lucene.document.Document;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.apache.solr.common.SolrInputDocument;
|
||||||
|
import org.apache.solr.schema.IndexSchema;
|
||||||
|
import org.apache.solr.schema.SchemaField;
|
||||||
|
import org.apache.solr.update.AddUpdateCommand;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This Class is a Request Update Processor to classify the document in input and add a field
|
||||||
|
* containing the class to the Document.
|
||||||
|
* It uses the Lucene Document Classification module, see {@link DocumentClassifier}.
|
||||||
|
*/
|
||||||
|
class ClassificationUpdateProcessor
|
||||||
|
extends UpdateRequestProcessor {
|
||||||
|
|
||||||
|
private String classFieldName; // the field to index the assigned class
|
||||||
|
|
||||||
|
private DocumentClassifier<BytesRef> classifier;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sole constructor
|
||||||
|
*
|
||||||
|
* @param inputFieldNames fields to be used as classifier's inputs
|
||||||
|
* @param classFieldName field to be used as classifier's output
|
||||||
|
* @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
|
||||||
|
* @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
|
||||||
|
* @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
|
||||||
|
* @param algorithm the name of the classifier to use
|
||||||
|
* @param next next update processor in the chain
|
||||||
|
* @param indexReader index reader
|
||||||
|
* @param schema schema
|
||||||
|
*/
|
||||||
|
public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
|
||||||
|
UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) {
|
||||||
|
super(next);
|
||||||
|
this.classFieldName = classFieldName;
|
||||||
|
Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
|
||||||
|
for (String fieldName : inputFieldNames) {
|
||||||
|
SchemaField fieldFromSolrSchema = schema.getField(fieldName);
|
||||||
|
Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
|
||||||
|
field2analyzer.put(fieldName, indexAnalyzer);
|
||||||
|
}
|
||||||
|
switch (algorithm) {
|
||||||
|
case "knn":
|
||||||
|
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
|
||||||
|
break;
|
||||||
|
case "bayes":
|
||||||
|
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param cmd the update command in input conaining the Document to classify
|
||||||
|
* @throws IOException If there is a low-level I/O error
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void processAdd(AddUpdateCommand cmd)
|
||||||
|
throws IOException {
|
||||||
|
SolrInputDocument doc = cmd.getSolrInputDocument();
|
||||||
|
Document luceneDocument = cmd.getLuceneDocument();
|
||||||
|
String assignedClass;
|
||||||
|
Object documentClass = doc.getFieldValue(classFieldName);
|
||||||
|
if (documentClass == null) {
|
||||||
|
ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
|
||||||
|
if (classificationResult != null) {
|
||||||
|
assignedClass = classificationResult.getAssignedClass().utf8ToString();
|
||||||
|
doc.addField(classFieldName, assignedClass);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
super.processAdd(cmd);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,223 @@
|
||||||
|
package org.apache.solr.update.processor;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.solr.common.SolrException;
|
||||||
|
import org.apache.solr.common.params.SolrParams;
|
||||||
|
import org.apache.solr.common.util.NamedList;
|
||||||
|
import org.apache.solr.request.SolrQueryRequest;
|
||||||
|
import org.apache.solr.response.SolrQueryResponse;
|
||||||
|
import org.apache.solr.schema.IndexSchema;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class implements an UpdateProcessorFactory for the Classification Update Processor.
|
||||||
|
* It takes in input a series of parameter that will be necessary to instantiate and use the Classifier
|
||||||
|
*/
|
||||||
|
public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessorFactory {
|
||||||
|
|
||||||
|
// Update Processor Config params
|
||||||
|
private static final String INPUT_FIELDS_PARAM = "inputFields";
|
||||||
|
private static final String CLASS_FIELD_PARAM = "classField";
|
||||||
|
private static final String ALGORITHM_PARAM = "algorithm";
|
||||||
|
private static final String KNN_MIN_TF_PARAM = "knn.minTf";
|
||||||
|
private static final String KNN_MIN_DF_PARAM = "knn.minDf";
|
||||||
|
private static final String KNN_K_PARAM = "knn.k";
|
||||||
|
|
||||||
|
//Update Processor Defaults
|
||||||
|
private static final int DEFAULT_MIN_TF = 1;
|
||||||
|
private static final int DEFAULT_MIN_DF = 1;
|
||||||
|
private static final int DEFAULT_K = 10;
|
||||||
|
private static final String DEFAULT_ALGORITHM = "knn";
|
||||||
|
|
||||||
|
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
|
||||||
|
|
||||||
|
private String classFieldName; // the field containing the class for the Document
|
||||||
|
|
||||||
|
private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
|
||||||
|
|
||||||
|
private int minTf; // knn specific - the minimum Term Frequency for considering a term
|
||||||
|
|
||||||
|
private int minDf; // knn specific - the minimum Document Frequency for considering a term
|
||||||
|
|
||||||
|
private int k; // knn specific - thw window of top results to evaluate, when assgning the class
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void init(final NamedList args) {
|
||||||
|
if (args != null) {
|
||||||
|
SolrParams params = SolrParams.toSolrParams(args);
|
||||||
|
|
||||||
|
String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
|
||||||
|
checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
|
||||||
|
inputFieldNames = fieldNames.split("\\,");
|
||||||
|
|
||||||
|
classFieldName = params.get(CLASS_FIELD_PARAM);
|
||||||
|
checkNotNull(CLASS_FIELD_PARAM, classFieldName);
|
||||||
|
|
||||||
|
algorithm = params.get(ALGORITHM_PARAM);
|
||||||
|
if (algorithm == null)
|
||||||
|
algorithm = DEFAULT_ALGORITHM;
|
||||||
|
|
||||||
|
minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
|
||||||
|
minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
|
||||||
|
k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Returns an Int parsed param or a default if the param is null
|
||||||
|
*
|
||||||
|
* @param params Solr params in input
|
||||||
|
* @param name the param name
|
||||||
|
* @param defaultValue the param default
|
||||||
|
* @return the Int value for the param
|
||||||
|
*/
|
||||||
|
private int getIntParam(SolrParams params, String name, int defaultValue) {
|
||||||
|
String paramString = params.get(name);
|
||||||
|
int paramInt;
|
||||||
|
if (paramString != null && !paramString.isEmpty()) {
|
||||||
|
paramInt = Integer.parseInt(paramString);
|
||||||
|
} else {
|
||||||
|
paramInt = defaultValue;
|
||||||
|
}
|
||||||
|
return paramInt;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkNotNull(String paramName, Object param) {
|
||||||
|
if (param == null) {
|
||||||
|
throw new SolrException
|
||||||
|
(SolrException.ErrorCode.SERVER_ERROR,
|
||||||
|
"Classification UpdateProcessor '" + paramName + "' can not be null");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
|
||||||
|
IndexSchema schema = req.getSchema();
|
||||||
|
LeafReader leafReader = req.getSearcher().getLeafReader();
|
||||||
|
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get field names used as classifier's inputs
|
||||||
|
*
|
||||||
|
* @return the input field names
|
||||||
|
*/
|
||||||
|
public String[] getInputFieldNames() {
|
||||||
|
return inputFieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set field names used as classifier's inputs
|
||||||
|
*
|
||||||
|
* @param inputFieldNames the input field names
|
||||||
|
*/
|
||||||
|
public void setInputFieldNames(String[] inputFieldNames) {
|
||||||
|
this.inputFieldNames = inputFieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get field names used as classifier's output
|
||||||
|
*
|
||||||
|
* @return the output field name
|
||||||
|
*/
|
||||||
|
public String getClassFieldName() {
|
||||||
|
return classFieldName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set field names used as classifier's output
|
||||||
|
*
|
||||||
|
* @param classFieldName the output field name
|
||||||
|
*/
|
||||||
|
public void setClassFieldName(String classFieldName) {
|
||||||
|
this.classFieldName = classFieldName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the name of the classifier algorithm used
|
||||||
|
*
|
||||||
|
* @return the classifier algorithm used
|
||||||
|
*/
|
||||||
|
public String getAlgorithm() {
|
||||||
|
return algorithm;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set the name of the classifier algorithm used
|
||||||
|
*
|
||||||
|
* @param algorithm the classifier algorithm used
|
||||||
|
*/
|
||||||
|
public void setAlgorithm(String algorithm) {
|
||||||
|
this.algorithm = algorithm;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the min term frequency value to be used in case algorithm is {@code "knn"}
|
||||||
|
*
|
||||||
|
* @return the min term frequency
|
||||||
|
*/
|
||||||
|
public int getMinTf() {
|
||||||
|
return minTf;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set the min term frequency value to be used in case algorithm is {@code "knn"}
|
||||||
|
*
|
||||||
|
* @param minTf the min term frequency
|
||||||
|
*/
|
||||||
|
public void setMinTf(int minTf) {
|
||||||
|
this.minTf = minTf;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the min document frequency value to be used in case algorithm is {@code "knn"}
|
||||||
|
*
|
||||||
|
* @return the min document frequency
|
||||||
|
*/
|
||||||
|
public int getMinDf() {
|
||||||
|
return minDf;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set the min document frequency value to be used in case algorithm is {@code "knn"}
|
||||||
|
*
|
||||||
|
* @param minDf the min document frequency
|
||||||
|
*/
|
||||||
|
public void setMinDf(int minDf) {
|
||||||
|
this.minDf = minDf;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
|
||||||
|
*
|
||||||
|
* @return the no. of neighbors to analyze
|
||||||
|
*/
|
||||||
|
public int getK() {
|
||||||
|
return k;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
|
||||||
|
*
|
||||||
|
* @param k the no. of neighbors to analyze
|
||||||
|
*/
|
||||||
|
public void setK(int k) {
|
||||||
|
this.k = k;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8" ?>
|
||||||
|
<!--
|
||||||
|
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
contributor license agreements. See the NOTICE file distributed with
|
||||||
|
this work for additional information regarding copyright ownership.
|
||||||
|
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
(the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
-->
|
||||||
|
<!--
|
||||||
|
Test Schema for a simple Classification Test
|
||||||
|
-->
|
||||||
|
<schema name="classification" version="1.1">
|
||||||
|
<fieldType name="string" class="solr.StrField"/>
|
||||||
|
<fieldType name="text_en" class="solr.TextField" positionIncrementGap="100">
|
||||||
|
<analyzer type="index">
|
||||||
|
<tokenizer class="solr.StandardTokenizerFactory"/>
|
||||||
|
<filter class="solr.LowerCaseFilterFactory"/>
|
||||||
|
<filter class="solr.EnglishPossessiveFilterFactory"/>
|
||||||
|
<filter class="solr.PorterStemFilterFactory"/>
|
||||||
|
</analyzer>
|
||||||
|
<analyzer type="query">
|
||||||
|
<tokenizer class="solr.StandardTokenizerFactory"/>
|
||||||
|
<filter class="solr.LowerCaseFilterFactory"/>
|
||||||
|
<filter class="solr.EnglishPossessiveFilterFactory"/>
|
||||||
|
<filter class="solr.PorterStemFilterFactory"/>
|
||||||
|
</analyzer>
|
||||||
|
</fieldType>
|
||||||
|
<field name="id" type="string" indexed="true" stored="true" required="true"/>
|
||||||
|
<field name="title" type="text_en" indexed="true" stored="true"/>
|
||||||
|
<field name="content" type="text_en" indexed="true" stored="true"/>
|
||||||
|
<field name="author" type="string" indexed="true" stored="true"/>
|
||||||
|
<field name="cat" type="string" indexed="true" stored="true"/>
|
||||||
|
<uniqueKey>id</uniqueKey>
|
||||||
|
</schema>
|
|
@ -0,0 +1,53 @@
|
||||||
|
<?xml version="1.0" ?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
contributor license agreements. See the NOTICE file distributed with
|
||||||
|
this work for additional information regarding copyright ownership.
|
||||||
|
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
(the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
<!--
|
||||||
|
Test Config for a simple Classification Update Request Processor Chain
|
||||||
|
-->
|
||||||
|
<config>
|
||||||
|
<luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion>
|
||||||
|
<xi:include xmlns:xi="http://www.w3.org/2001/XInclude" href="solrconfig.snippet.randomindexconfig.xml"/>
|
||||||
|
<requestHandler name="standard" class="solr.StandardRequestHandler"></requestHandler>
|
||||||
|
<directoryFactory name="DirectoryFactory" class="${solr.directoryFactory:solr.RAMDirectoryFactory}"/>
|
||||||
|
<schemaFactory class="ClassicIndexSchemaFactory"/>
|
||||||
|
|
||||||
|
<updateHandler class="solr.DirectUpdateHandler2">
|
||||||
|
<updateLog enable="${enable.update.log:true}">
|
||||||
|
<str name="dir">${solr.ulog.dir:}</str>
|
||||||
|
</updateLog>
|
||||||
|
|
||||||
|
<commitWithin>
|
||||||
|
<softCommit>${solr.commitwithin.softcommit:true}</softCommit>
|
||||||
|
</commitWithin>
|
||||||
|
|
||||||
|
</updateHandler>
|
||||||
|
|
||||||
|
<updateRequestProcessorChain name="classification">
|
||||||
|
<processor class="solr.ClassificationUpdateProcessorFactory">
|
||||||
|
<str name="inputFields">title,content,author</str>
|
||||||
|
<str name="classField">cat</str>
|
||||||
|
<!-- Knn algorithm specific-->
|
||||||
|
<str name="algorithm">knn</str>
|
||||||
|
<str name="knn.minTf">1</str>
|
||||||
|
<str name="knn.minDf">1</str>
|
||||||
|
<str name="knn.k">5</str>
|
||||||
|
</processor>
|
||||||
|
<processor class="solr.RunUpdateProcessorFactory"/>
|
||||||
|
</updateRequestProcessorChain>
|
||||||
|
</config>
|
|
@ -0,0 +1,234 @@
|
||||||
|
package org.apache.solr.update.processor;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.lucene.document.Document;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.apache.lucene.search.TermQuery;
|
||||||
|
import org.apache.lucene.search.TopDocs;
|
||||||
|
import org.apache.solr.SolrTestCaseJ4;
|
||||||
|
import org.apache.solr.common.SolrException;
|
||||||
|
import org.apache.solr.common.params.MultiMapSolrParams;
|
||||||
|
import org.apache.solr.common.params.SolrParams;
|
||||||
|
import org.apache.solr.common.params.UpdateParams;
|
||||||
|
import org.apache.solr.common.util.ContentStream;
|
||||||
|
import org.apache.solr.common.util.ContentStreamBase;
|
||||||
|
import org.apache.solr.common.util.NamedList;
|
||||||
|
import org.apache.solr.handler.UpdateRequestHandler;
|
||||||
|
import org.apache.solr.request.SolrQueryRequest;
|
||||||
|
import org.apache.solr.request.SolrQueryRequestBase;
|
||||||
|
import org.apache.solr.response.SolrQueryResponse;
|
||||||
|
import org.apache.solr.search.SolrIndexSearcher;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
|
||||||
|
*/
|
||||||
|
public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
|
||||||
|
// field names are used in accordance with the solrconfig and schema supplied
|
||||||
|
private static final String ID = "id";
|
||||||
|
private static final String TITLE = "title";
|
||||||
|
private static final String CONTENT = "content";
|
||||||
|
private static final String AUTHOR = "author";
|
||||||
|
private static final String CLASS = "cat";
|
||||||
|
|
||||||
|
private static final String CHAIN = "classification";
|
||||||
|
|
||||||
|
|
||||||
|
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
|
||||||
|
private NamedList args = new NamedList<String>();
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void beforeClass() throws Exception {
|
||||||
|
System.setProperty("enable.update.log", "false");
|
||||||
|
initCore("solrconfig-classification.xml", "schema-classification.xml");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@Before
|
||||||
|
public void setUp() throws Exception {
|
||||||
|
super.setUp();
|
||||||
|
clearIndex();
|
||||||
|
assertU(commit());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void initArgs() {
|
||||||
|
args.add("inputFields", "inputField1,inputField2");
|
||||||
|
args.add("classField", "classField1");
|
||||||
|
args.add("algorithm", "bayes");
|
||||||
|
args.add("knn.k", "9");
|
||||||
|
args.add("knn.minDf", "8");
|
||||||
|
args.add("knn.minTf", "10");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFullInit() {
|
||||||
|
cFactoryToTest.init(args);
|
||||||
|
|
||||||
|
String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
|
||||||
|
assertEquals("inputField1", inputFieldNames[0]);
|
||||||
|
assertEquals("inputField2", inputFieldNames[1]);
|
||||||
|
assertEquals("classField1", cFactoryToTest.getClassFieldName());
|
||||||
|
assertEquals("bayes", cFactoryToTest.getAlgorithm());
|
||||||
|
assertEquals(8, cFactoryToTest.getMinDf());
|
||||||
|
assertEquals(10, cFactoryToTest.getMinTf());
|
||||||
|
assertEquals(9, cFactoryToTest.getK());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInitEmptyInputField() {
|
||||||
|
args.removeAll("inputFields");
|
||||||
|
try {
|
||||||
|
cFactoryToTest.init(args);
|
||||||
|
} catch (SolrException e) {
|
||||||
|
assertEquals("Classification UpdateProcessor 'inputFields' can not be null", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInitEmptyClassField() {
|
||||||
|
args.removeAll("classField");
|
||||||
|
try {
|
||||||
|
cFactoryToTest.init(args);
|
||||||
|
} catch (SolrException e) {
|
||||||
|
assertEquals("Classification UpdateProcessor 'classField' can not be null", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDefaults() {
|
||||||
|
args.removeAll("algorithm");
|
||||||
|
args.removeAll("knn.k");
|
||||||
|
args.removeAll("knn.minDf");
|
||||||
|
args.removeAll("knn.minTf");
|
||||||
|
cFactoryToTest.init(args);
|
||||||
|
assertEquals("knn", cFactoryToTest.getAlgorithm());
|
||||||
|
assertEquals(1, cFactoryToTest.getMinDf());
|
||||||
|
assertEquals(1, cFactoryToTest.getMinTf());
|
||||||
|
assertEquals(10, cFactoryToTest.getK());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBasicClassification() throws Exception {
|
||||||
|
prepareTrainedIndex();
|
||||||
|
// To be classified,we index documents without a class and verify the expected one is returned
|
||||||
|
addDoc(adoc(ID, "10",
|
||||||
|
TITLE, "word4 word4 word4",
|
||||||
|
CONTENT, "word5 word5 ",
|
||||||
|
AUTHOR, "Name1 Surname1"));
|
||||||
|
addDoc(adoc(ID, "11",
|
||||||
|
TITLE, "word1 word1",
|
||||||
|
CONTENT, "word2 word2",
|
||||||
|
AUTHOR, "Name Surname"));
|
||||||
|
addDoc(commit());
|
||||||
|
|
||||||
|
Document doc10 = getDoc("10");
|
||||||
|
assertEquals("class2", doc10.get(CLASS));
|
||||||
|
Document doc11 = getDoc("11");
|
||||||
|
assertEquals("class1", doc11.get(CLASS));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Index some example documents with a class manually assigned.
|
||||||
|
* This will be our trained model.
|
||||||
|
*
|
||||||
|
* @throws Exception If there is a low-level I/O error
|
||||||
|
*/
|
||||||
|
private void prepareTrainedIndex() throws Exception {
|
||||||
|
//class1
|
||||||
|
addDoc(adoc(ID, "1",
|
||||||
|
TITLE, "word1 word1 word1",
|
||||||
|
CONTENT, "word2 word2 word2",
|
||||||
|
AUTHOR, "Name Surname",
|
||||||
|
CLASS, "class1"));
|
||||||
|
addDoc(adoc(ID, "2",
|
||||||
|
TITLE, "word1 word1",
|
||||||
|
CONTENT, "word2 word2",
|
||||||
|
AUTHOR, "Name Surname",
|
||||||
|
CLASS, "class1"));
|
||||||
|
addDoc(adoc(ID, "3",
|
||||||
|
TITLE, "word1 word1 word1",
|
||||||
|
CONTENT, "word2",
|
||||||
|
AUTHOR, "Name Surname",
|
||||||
|
CLASS, "class1"));
|
||||||
|
addDoc(adoc(ID, "4",
|
||||||
|
TITLE, "word1 word1 word1",
|
||||||
|
CONTENT, "word2 word2 word2",
|
||||||
|
AUTHOR, "Name Surname",
|
||||||
|
CLASS, "class1"));
|
||||||
|
//class2
|
||||||
|
addDoc(adoc(ID, "5",
|
||||||
|
TITLE, "word4 word4 word4",
|
||||||
|
CONTENT, "word5 word5",
|
||||||
|
AUTHOR, "Name1 Surname1",
|
||||||
|
CLASS, "class2"));
|
||||||
|
addDoc(adoc(ID, "6",
|
||||||
|
TITLE, "word4 word4",
|
||||||
|
CONTENT, "word5",
|
||||||
|
AUTHOR, "Name1 Surname1",
|
||||||
|
CLASS, "class2"));
|
||||||
|
addDoc(adoc(ID, "7",
|
||||||
|
TITLE, "word4 word4 word4",
|
||||||
|
CONTENT, "word5 word5 word5",
|
||||||
|
AUTHOR, "Name1 Surname1",
|
||||||
|
CLASS, "class2"));
|
||||||
|
addDoc(adoc(ID, "8",
|
||||||
|
TITLE, "word4",
|
||||||
|
CONTENT, "word5 word5 word5 word5",
|
||||||
|
AUTHOR, "Name1 Surname1",
|
||||||
|
CLASS, "class2"));
|
||||||
|
addDoc(commit());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Document getDoc(String id) throws IOException {
|
||||||
|
try (SolrQueryRequest req = req()) {
|
||||||
|
SolrIndexSearcher searcher = req.getSearcher();
|
||||||
|
TermQuery query = new TermQuery(new Term(ID, id));
|
||||||
|
TopDocs doc1 = searcher.search(query, 1);
|
||||||
|
ScoreDoc scoreDoc = doc1.scoreDocs[0];
|
||||||
|
return searcher.doc(scoreDoc.doc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void addDoc(String doc) throws Exception {
|
||||||
|
Map<String, String[]> params = new HashMap<>();
|
||||||
|
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
|
||||||
|
params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
|
||||||
|
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
|
||||||
|
(SolrParams) mmparams) {
|
||||||
|
};
|
||||||
|
|
||||||
|
UpdateRequestHandler handler = new UpdateRequestHandler();
|
||||||
|
handler.init(null);
|
||||||
|
ArrayList<ContentStream> streams = new ArrayList<>(2);
|
||||||
|
streams.add(new ContentStreamBase.StringStream(doc));
|
||||||
|
req.setContentStreams(streams);
|
||||||
|
handler.handleRequestBody(req, new SolrQueryResponse());
|
||||||
|
req.close();
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue