mirror of https://github.com/apache/lucene.git
SOLR-8871 - various improvements to ClassificationURP
This commit is contained in:
parent
e9e4715dd2
commit
5ad741eef8
|
@ -16,22 +16,12 @@
|
|||
*/
|
||||
package org.apache.solr.uima.processor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.util.LuceneTestCase.Slow;
|
||||
import org.apache.solr.SolrTestCaseJ4;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.params.MultiMapSolrParams;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.params.UpdateParams;
|
||||
import org.apache.solr.common.util.ContentStream;
|
||||
import org.apache.solr.common.util.ContentStreamBase;
|
||||
import org.apache.solr.core.SolrCore;
|
||||
import org.apache.solr.handler.UpdateRequestHandler;
|
||||
import org.apache.solr.request.SolrQueryRequestBase;
|
||||
import org.apache.solr.response.SolrQueryResponse;
|
||||
import org.apache.solr.uima.processor.SolrUIMAConfiguration.MapField;
|
||||
import org.apache.solr.update.processor.UpdateRequestProcessor;
|
||||
import org.apache.solr.update.processor.UpdateRequestProcessorChain;
|
||||
|
@ -188,19 +178,4 @@ public class UIMAUpdateRequestProcessorTest extends SolrTestCaseJ4 {
|
|||
}
|
||||
}
|
||||
|
||||
private void addDoc(String chain, String doc) throws Exception {
|
||||
Map<String, String[]> params = new HashMap<>();
|
||||
params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
|
||||
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
|
||||
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(), (SolrParams) mmparams) {
|
||||
};
|
||||
|
||||
UpdateRequestHandler handler = new UpdateRequestHandler();
|
||||
handler.init(null);
|
||||
ArrayList<ContentStream> streams = new ArrayList<>(2);
|
||||
streams.add(new ContentStreamBase.StringStream(doc));
|
||||
req.setContentStreams(streams);
|
||||
handler.handleRequestBody(req, new SolrQueryResponse());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.solr.update.processor;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
|
@ -33,6 +34,7 @@ import org.apache.solr.common.SolrInputDocument;
|
|||
import org.apache.solr.schema.IndexSchema;
|
||||
import org.apache.solr.schema.SchemaField;
|
||||
import org.apache.solr.update.AddUpdateCommand;
|
||||
import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm;
|
||||
|
||||
/**
|
||||
* This Class is a Request Update Processor to classify the document in input and add a field
|
||||
|
@ -42,43 +44,54 @@ import org.apache.solr.update.AddUpdateCommand;
|
|||
class ClassificationUpdateProcessor
|
||||
extends UpdateRequestProcessor {
|
||||
|
||||
private String classFieldName; // the field to index the assigned class
|
||||
|
||||
private final String trainingClassField;
|
||||
private final String predictedClassField;
|
||||
private final int maxOutputClasses;
|
||||
private DocumentClassifier<BytesRef> classifier;
|
||||
|
||||
/**
|
||||
* Sole constructor
|
||||
*
|
||||
* @param inputFieldNames fields to be used as classifier's inputs
|
||||
* @param classFieldName field to be used as classifier's output
|
||||
* @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
|
||||
* @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
|
||||
* @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
|
||||
* @param algorithm the name of the classifier to use
|
||||
* @param classificationParams classification advanced params
|
||||
* @param next next update processor in the chain
|
||||
* @param indexReader index reader
|
||||
* @param schema schema
|
||||
*/
|
||||
public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
|
||||
UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
|
||||
public ClassificationUpdateProcessor(ClassificationUpdateProcessorParams classificationParams, UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
|
||||
super(next);
|
||||
this.classFieldName = classFieldName;
|
||||
Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
|
||||
this.trainingClassField = classificationParams.getTrainingClassField();
|
||||
this.predictedClassField = classificationParams.getPredictedClassField();
|
||||
this.maxOutputClasses = classificationParams.getMaxPredictedClasses();
|
||||
String[] inputFieldNamesWithBoost = classificationParams.getInputFieldNames();
|
||||
Algorithm classificationAlgorithm = classificationParams.getAlgorithm();
|
||||
|
||||
Map<String, Analyzer> field2analyzer = new HashMap<>();
|
||||
String[] inputFieldNames = this.removeBoost(inputFieldNamesWithBoost);
|
||||
for (String fieldName : inputFieldNames) {
|
||||
SchemaField fieldFromSolrSchema = schema.getField(fieldName);
|
||||
Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
|
||||
field2analyzer.put(fieldName, indexAnalyzer);
|
||||
}
|
||||
switch (algorithm) {
|
||||
case "knn":
|
||||
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
|
||||
switch (classificationAlgorithm) {
|
||||
case KNN:
|
||||
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, classificationParams.getTrainingFilterQuery(), classificationParams.getK(), classificationParams.getMinDf(), classificationParams.getMinTf(), trainingClassField, field2analyzer, inputFieldNamesWithBoost);
|
||||
break;
|
||||
case "bayes":
|
||||
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
|
||||
case BAYES:
|
||||
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, trainingClassField, field2analyzer, inputFieldNamesWithBoost);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private String[] removeBoost(String[] inputFieldNamesWithBoost) {
|
||||
String[] inputFieldNames = new String[inputFieldNamesWithBoost.length];
|
||||
for (int i = 0; i < inputFieldNamesWithBoost.length; i++) {
|
||||
String singleFieldNameWithBoost = inputFieldNamesWithBoost[i];
|
||||
String[] fieldName2boost = singleFieldNameWithBoost.split("\\^");
|
||||
inputFieldNames[i] = fieldName2boost[0];
|
||||
}
|
||||
return inputFieldNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param cmd the update command in input containing the Document to classify
|
||||
* @throws IOException If there is a low-level I/O error
|
||||
|
@ -89,12 +102,14 @@ class ClassificationUpdateProcessor
|
|||
SolrInputDocument doc = cmd.getSolrInputDocument();
|
||||
Document luceneDocument = cmd.getLuceneDocument();
|
||||
String assignedClass;
|
||||
Object documentClass = doc.getFieldValue(classFieldName);
|
||||
Object documentClass = doc.getFieldValue(trainingClassField);
|
||||
if (documentClass == null) {
|
||||
ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
|
||||
if (classificationResult != null) {
|
||||
assignedClass = classificationResult.getAssignedClass().utf8ToString();
|
||||
doc.addField(classFieldName, assignedClass);
|
||||
List<ClassificationResult<BytesRef>> assignedClassifications = classifier.getClasses(luceneDocument, maxOutputClasses);
|
||||
if (assignedClassifications != null) {
|
||||
for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) {
|
||||
assignedClass = singleClassification.getAssignedClass().utf8ToString();
|
||||
doc.addField(predictedClassField, assignedClass);
|
||||
}
|
||||
}
|
||||
}
|
||||
super.processAdd(cmd);
|
||||
|
|
|
@ -18,12 +18,17 @@
|
|||
package org.apache.solr.update.processor;
|
||||
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.response.SolrQueryResponse;
|
||||
import org.apache.solr.schema.IndexSchema;
|
||||
import org.apache.solr.search.LuceneQParser;
|
||||
import org.apache.solr.search.SyntaxError;
|
||||
|
||||
import static org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm.KNN;
|
||||
|
||||
/**
|
||||
* This class implements an UpdateProcessorFactory for the Classification Update Processor.
|
||||
|
@ -33,49 +38,67 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
|
|||
|
||||
// Update Processor Config params
|
||||
private static final String INPUT_FIELDS_PARAM = "inputFields";
|
||||
private static final String CLASS_FIELD_PARAM = "classField";
|
||||
private static final String TRAINING_CLASS_FIELD_PARAM = "classField";
|
||||
private static final String PREDICTED_CLASS_FIELD_PARAM = "predictedClassField";
|
||||
private static final String MAX_CLASSES_TO_ASSIGN_PARAM = "predictedClass.maxCount";
|
||||
private static final String ALGORITHM_PARAM = "algorithm";
|
||||
private static final String KNN_MIN_TF_PARAM = "knn.minTf";
|
||||
private static final String KNN_MIN_DF_PARAM = "knn.minDf";
|
||||
private static final String KNN_K_PARAM = "knn.k";
|
||||
private static final String KNN_FILTER_QUERY = "knn.filterQuery";
|
||||
|
||||
public enum Algorithm {KNN, BAYES}
|
||||
|
||||
//Update Processor Defaults
|
||||
private static final int DEFAULT_MAX_CLASSES_TO_ASSIGN = 1;
|
||||
private static final int DEFAULT_MIN_TF = 1;
|
||||
private static final int DEFAULT_MIN_DF = 1;
|
||||
private static final int DEFAULT_K = 10;
|
||||
private static final String DEFAULT_ALGORITHM = "knn";
|
||||
private static final Algorithm DEFAULT_ALGORITHM = KNN;
|
||||
|
||||
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
|
||||
|
||||
private String classFieldName; // the field containing the class for the Document
|
||||
|
||||
private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
|
||||
|
||||
private int minTf; // knn specific - the minimum Term Frequency for considering a term
|
||||
|
||||
private int minDf; // knn specific - the minimum Document Frequency for considering a term
|
||||
|
||||
private int k; // knn specific - thw window of top results to evaluate, when assigning the class
|
||||
private SolrParams params;
|
||||
private ClassificationUpdateProcessorParams classificationParams;
|
||||
|
||||
@Override
|
||||
public void init(final NamedList args) {
|
||||
if (args != null) {
|
||||
SolrParams params = SolrParams.toSolrParams(args);
|
||||
params = SolrParams.toSolrParams(args);
|
||||
classificationParams = new ClassificationUpdateProcessorParams();
|
||||
|
||||
String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
|
||||
checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
|
||||
inputFieldNames = fieldNames.split("\\,");
|
||||
classificationParams.setInputFieldNames(fieldNames.split("\\,"));
|
||||
|
||||
classFieldName = params.get(CLASS_FIELD_PARAM);
|
||||
checkNotNull(CLASS_FIELD_PARAM, classFieldName);
|
||||
String trainingClassField = (params.get(TRAINING_CLASS_FIELD_PARAM));
|
||||
checkNotNull(TRAINING_CLASS_FIELD_PARAM, trainingClassField);
|
||||
classificationParams.setTrainingClassField(trainingClassField);
|
||||
|
||||
algorithm = params.get(ALGORITHM_PARAM);
|
||||
if (algorithm == null)
|
||||
algorithm = DEFAULT_ALGORITHM;
|
||||
String predictedClassField = (params.get(PREDICTED_CLASS_FIELD_PARAM));
|
||||
if (predictedClassField == null || predictedClassField.isEmpty()) {
|
||||
predictedClassField = trainingClassField;
|
||||
}
|
||||
classificationParams.setPredictedClassField(predictedClassField);
|
||||
|
||||
minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
|
||||
minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
|
||||
k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
|
||||
classificationParams.setMaxPredictedClasses(getIntParam(params, MAX_CLASSES_TO_ASSIGN_PARAM, DEFAULT_MAX_CLASSES_TO_ASSIGN));
|
||||
|
||||
String algorithmString = params.get(ALGORITHM_PARAM);
|
||||
Algorithm classificationAlgorithm;
|
||||
try {
|
||||
if (algorithmString == null || Algorithm.valueOf(algorithmString.toUpperCase()) == null) {
|
||||
classificationAlgorithm = DEFAULT_ALGORITHM;
|
||||
} else {
|
||||
classificationAlgorithm = Algorithm.valueOf(algorithmString.toUpperCase());
|
||||
}
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new SolrException
|
||||
(SolrException.ErrorCode.SERVER_ERROR,
|
||||
"Classification UpdateProcessor Algorithm: '" + algorithmString + "' not supported");
|
||||
}
|
||||
classificationParams.setAlgorithm(classificationAlgorithm);
|
||||
|
||||
classificationParams.setMinTf(getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF));
|
||||
classificationParams.setMinDf(getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF));
|
||||
classificationParams.setK(getIntParam(params, KNN_K_PARAM, DEFAULT_K));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -108,116 +131,34 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
|
|||
|
||||
@Override
|
||||
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
|
||||
String trainingFilterQueryString = (params.get(KNN_FILTER_QUERY));
|
||||
try {
|
||||
if (trainingFilterQueryString != null && !trainingFilterQueryString.isEmpty()) {
|
||||
Query trainingFilterQuery = this.parseFilterQuery(trainingFilterQueryString, params, req);
|
||||
classificationParams.setTrainingFilterQuery(trainingFilterQuery);
|
||||
}
|
||||
} catch (SyntaxError | RuntimeException syntaxError) {
|
||||
throw new SolrException
|
||||
(SolrException.ErrorCode.SERVER_ERROR,
|
||||
"Classification UpdateProcessor Training Filter Query: '" + trainingFilterQueryString + "' is not supported", syntaxError);
|
||||
}
|
||||
|
||||
IndexSchema schema = req.getSchema();
|
||||
IndexReader indexReader = req.getSearcher().getIndexReader();
|
||||
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
|
||||
|
||||
return new ClassificationUpdateProcessor(classificationParams, next, indexReader, schema);
|
||||
}
|
||||
|
||||
/**
|
||||
* get field names used as classifier's inputs
|
||||
*
|
||||
* @return the input field names
|
||||
*/
|
||||
public String[] getInputFieldNames() {
|
||||
return inputFieldNames;
|
||||
private Query parseFilterQuery(String trainingFilterQueryString, SolrParams params, SolrQueryRequest req) throws SyntaxError {
|
||||
LuceneQParser parser = new LuceneQParser(trainingFilterQueryString, null, params, req);
|
||||
return parser.parse();
|
||||
}
|
||||
|
||||
/**
|
||||
* set field names used as classifier's inputs
|
||||
*
|
||||
* @param inputFieldNames the input field names
|
||||
*/
|
||||
public void setInputFieldNames(String[] inputFieldNames) {
|
||||
this.inputFieldNames = inputFieldNames;
|
||||
public ClassificationUpdateProcessorParams getClassificationParams() {
|
||||
return classificationParams;
|
||||
}
|
||||
|
||||
/**
|
||||
* get field names used as classifier's output
|
||||
*
|
||||
* @return the output field name
|
||||
*/
|
||||
public String getClassFieldName() {
|
||||
return classFieldName;
|
||||
}
|
||||
|
||||
/**
|
||||
* set field names used as classifier's output
|
||||
*
|
||||
* @param classFieldName the output field name
|
||||
*/
|
||||
public void setClassFieldName(String classFieldName) {
|
||||
this.classFieldName = classFieldName;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the name of the classifier algorithm used
|
||||
*
|
||||
* @return the classifier algorithm used
|
||||
*/
|
||||
public String getAlgorithm() {
|
||||
return algorithm;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the name of the classifier algorithm used
|
||||
*
|
||||
* @param algorithm the classifier algorithm used
|
||||
*/
|
||||
public void setAlgorithm(String algorithm) {
|
||||
this.algorithm = algorithm;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the min term frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @return the min term frequency
|
||||
*/
|
||||
public int getMinTf() {
|
||||
return minTf;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the min term frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @param minTf the min term frequency
|
||||
*/
|
||||
public void setMinTf(int minTf) {
|
||||
this.minTf = minTf;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the min document frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @return the min document frequency
|
||||
*/
|
||||
public int getMinDf() {
|
||||
return minDf;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the min document frequency value to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @param minDf the min document frequency
|
||||
*/
|
||||
public void setMinDf(int minDf) {
|
||||
this.minDf = minDf;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @return the no. of neighbors to analyze
|
||||
*/
|
||||
public int getK() {
|
||||
return k;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
|
||||
*
|
||||
* @param k the no. of neighbors to analyze
|
||||
*/
|
||||
public void setK(int k) {
|
||||
this.k = k;
|
||||
public void setClassificationParams(ClassificationUpdateProcessorParams classificationParams) {
|
||||
this.classificationParams = classificationParams;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.update.processor;
|
||||
|
||||
import org.apache.lucene.search.Query;
|
||||
|
||||
public class ClassificationUpdateProcessorParams {
|
||||
|
||||
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
|
||||
|
||||
private Query trainingFilterQuery; // a filter query to reduce the training set to a subset
|
||||
|
||||
private String trainingClassField; // the field containing the class for the Document
|
||||
|
||||
private String predictedClassField; // the field that will contain the predicted class
|
||||
|
||||
private int maxPredictedClasses; // the max number of classes to assign
|
||||
|
||||
private ClassificationUpdateProcessorFactory.Algorithm algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
|
||||
|
||||
private int minTf; // knn specific - the minimum Term Frequency for considering a term
|
||||
|
||||
private int minDf; // knn specific - the minimum Document Frequency for considering a term
|
||||
|
||||
private int k; // knn specific - thw window of top results to evaluate, when assigning the class
|
||||
|
||||
public String[] getInputFieldNames() {
|
||||
return inputFieldNames;
|
||||
}
|
||||
|
||||
public void setInputFieldNames(String[] inputFieldNames) {
|
||||
this.inputFieldNames = inputFieldNames;
|
||||
}
|
||||
|
||||
public Query getTrainingFilterQuery() {
|
||||
return trainingFilterQuery;
|
||||
}
|
||||
|
||||
public void setTrainingFilterQuery(Query trainingFilterQuery) {
|
||||
this.trainingFilterQuery = trainingFilterQuery;
|
||||
}
|
||||
|
||||
public String getTrainingClassField() {
|
||||
return trainingClassField;
|
||||
}
|
||||
|
||||
public void setTrainingClassField(String trainingClassField) {
|
||||
this.trainingClassField = trainingClassField;
|
||||
}
|
||||
|
||||
public String getPredictedClassField() {
|
||||
return predictedClassField;
|
||||
}
|
||||
|
||||
public void setPredictedClassField(String predictedClassField) {
|
||||
this.predictedClassField = predictedClassField;
|
||||
}
|
||||
|
||||
public int getMaxPredictedClasses() {
|
||||
return maxPredictedClasses;
|
||||
}
|
||||
|
||||
public void setMaxPredictedClasses(int maxPredictedClasses) {
|
||||
this.maxPredictedClasses = maxPredictedClasses;
|
||||
}
|
||||
|
||||
public ClassificationUpdateProcessorFactory.Algorithm getAlgorithm() {
|
||||
return algorithm;
|
||||
}
|
||||
|
||||
public void setAlgorithm(ClassificationUpdateProcessorFactory.Algorithm algorithm) {
|
||||
this.algorithm = algorithm;
|
||||
}
|
||||
|
||||
public int getMinTf() {
|
||||
return minTf;
|
||||
}
|
||||
|
||||
public void setMinTf(int minTf) {
|
||||
this.minTf = minTf;
|
||||
}
|
||||
|
||||
public int getMinDf() {
|
||||
return minDf;
|
||||
}
|
||||
|
||||
public void setMinDf(int minDf) {
|
||||
this.minDf = minDf;
|
||||
}
|
||||
|
||||
public int getK() {
|
||||
return k;
|
||||
}
|
||||
|
||||
public void setK(int k) {
|
||||
this.k = k;
|
||||
}
|
||||
}
|
|
@ -47,6 +47,21 @@
|
|||
<str name="knn.minTf">1</str>
|
||||
<str name="knn.minDf">1</str>
|
||||
<str name="knn.k">5</str>
|
||||
<str name="knn.filterQuery">cat:(class1 OR class2)</str>
|
||||
</processor>
|
||||
<processor class="solr.RunUpdateProcessorFactory"/>
|
||||
</updateRequestProcessorChain>
|
||||
|
||||
<updateRequestProcessorChain name="classification-unsupported-filterQuery">
|
||||
<processor class="solr.ClassificationUpdateProcessorFactory">
|
||||
<str name="inputFields">title,content,author</str>
|
||||
<str name="classField">cat</str>
|
||||
<!-- Knn algorithm specific-->
|
||||
<str name="algorithm">knn</str>
|
||||
<str name="knn.minTf">1</str>
|
||||
<str name="knn.minDf">1</str>
|
||||
<str name="knn.k">5</str>
|
||||
<str name="knn.filterQuery">not valid ( lucene query</str>
|
||||
</processor>
|
||||
<processor class="solr.RunUpdateProcessorFactory"/>
|
||||
</updateRequestProcessorChain>
|
||||
|
|
|
@ -14,71 +14,31 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.update.processor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.solr.SolrTestCaseJ4;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.params.MultiMapSolrParams;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.params.UpdateParams;
|
||||
import org.apache.solr.common.util.ContentStream;
|
||||
import org.apache.solr.common.util.ContentStreamBase;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.handler.UpdateRequestHandler;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.request.SolrQueryRequestBase;
|
||||
import org.apache.solr.response.SolrQueryResponse;
|
||||
import org.apache.solr.search.SolrIndexSearcher;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.hamcrest.core.Is.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
|
||||
* Tests for {@link ClassificationUpdateProcessorFactory}
|
||||
*/
|
||||
public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
|
||||
// field names are used in accordance with the solrconfig and schema supplied
|
||||
private static final String ID = "id";
|
||||
private static final String TITLE = "title";
|
||||
private static final String CONTENT = "content";
|
||||
private static final String AUTHOR = "author";
|
||||
private static final String CLASS = "cat";
|
||||
|
||||
private static final String CHAIN = "classification";
|
||||
|
||||
|
||||
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
|
||||
private NamedList args = new NamedList<String>();
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
System.setProperty("enable.update.log", "false");
|
||||
initCore("solrconfig-classification.xml", "schema-classification.xml");
|
||||
}
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
clearIndex();
|
||||
assertU(commit());
|
||||
}
|
||||
|
||||
@Before
|
||||
public void initArgs() {
|
||||
args.add("inputFields", "inputField1,inputField2");
|
||||
args.add("classField", "classField1");
|
||||
args.add("predictedClassField", "classFieldX");
|
||||
args.add("algorithm", "bayes");
|
||||
args.add("knn.k", "9");
|
||||
args.add("knn.minDf", "8");
|
||||
|
@ -86,22 +46,23 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testFullInit() {
|
||||
public void init_fullArgs_shouldInitFullClassificationParams() {
|
||||
cFactoryToTest.init(args);
|
||||
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
|
||||
|
||||
String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
|
||||
String[] inputFieldNames = classificationParams.getInputFieldNames();
|
||||
assertEquals("inputField1", inputFieldNames[0]);
|
||||
assertEquals("inputField2", inputFieldNames[1]);
|
||||
assertEquals("classField1", cFactoryToTest.getClassFieldName());
|
||||
assertEquals("bayes", cFactoryToTest.getAlgorithm());
|
||||
assertEquals(8, cFactoryToTest.getMinDf());
|
||||
assertEquals(10, cFactoryToTest.getMinTf());
|
||||
assertEquals(9, cFactoryToTest.getK());
|
||||
|
||||
assertEquals("classField1", classificationParams.getTrainingClassField());
|
||||
assertEquals("classFieldX", classificationParams.getPredictedClassField());
|
||||
assertEquals(ClassificationUpdateProcessorFactory.Algorithm.BAYES, classificationParams.getAlgorithm());
|
||||
assertEquals(8, classificationParams.getMinDf());
|
||||
assertEquals(10, classificationParams.getMinTf());
|
||||
assertEquals(9, classificationParams.getK());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInitEmptyInputField() {
|
||||
public void init_emptyInputFields_shouldThrowExceptionWithDetailedMessage() {
|
||||
args.removeAll("inputFields");
|
||||
try {
|
||||
cFactoryToTest.init(args);
|
||||
|
@ -111,7 +72,7 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testInitEmptyClassField() {
|
||||
public void init_emptyClassField_shouldThrowExceptionWithDetailedMessage() {
|
||||
args.removeAll("classField");
|
||||
try {
|
||||
cFactoryToTest.init(args);
|
||||
|
@ -121,114 +82,53 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testDefaults() {
|
||||
public void init_emptyPredictedClassField_shouldDefaultToTrainingClassField() {
|
||||
args.removeAll("predictedClassField");
|
||||
|
||||
cFactoryToTest.init(args);
|
||||
|
||||
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
|
||||
assertThat(classificationParams.getPredictedClassField(), is("classField1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void init_unsupportedAlgorithm_shouldThrowExceptionWithDetailedMessage() {
|
||||
args.removeAll("algorithm");
|
||||
args.add("algorithm", "unsupported");
|
||||
try {
|
||||
cFactoryToTest.init(args);
|
||||
} catch (SolrException e) {
|
||||
assertEquals("Classification UpdateProcessor Algorithm: 'unsupported' not supported", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void init_unsupportedFilterQuery_shouldThrowExceptionWithDetailedMessage() {
|
||||
UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
|
||||
SolrQueryRequest mockRequest = mock(SolrQueryRequest.class);
|
||||
SolrQueryResponse mockResponse = mock(SolrQueryResponse.class);
|
||||
args.add("knn.filterQuery", "not supported query");
|
||||
try {
|
||||
cFactoryToTest.init(args);
|
||||
/* parsing failure happens because of the mocks, fine enough to check a proper exception propagation */
|
||||
cFactoryToTest.getInstance(mockRequest, mockResponse, mockProcessor);
|
||||
} catch (SolrException e) {
|
||||
assertEquals("Classification UpdateProcessor Training Filter Query: 'not supported query' is not supported", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void init_emptyArgs_shouldDefaultClassificationParams() {
|
||||
args.removeAll("algorithm");
|
||||
args.removeAll("knn.k");
|
||||
args.removeAll("knn.minDf");
|
||||
args.removeAll("knn.minTf");
|
||||
cFactoryToTest.init(args);
|
||||
assertEquals("knn", cFactoryToTest.getAlgorithm());
|
||||
assertEquals(1, cFactoryToTest.getMinDf());
|
||||
assertEquals(1, cFactoryToTest.getMinTf());
|
||||
assertEquals(10, cFactoryToTest.getK());
|
||||
}
|
||||
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
|
||||
|
||||
@Test
|
||||
public void testBasicClassification() throws Exception {
|
||||
prepareTrainedIndex();
|
||||
// To be classified,we index documents without a class and verify the expected one is returned
|
||||
addDoc(adoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 ",
|
||||
AUTHOR, "Name1 Surname1"));
|
||||
addDoc(adoc(ID, "11",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname"));
|
||||
addDoc(commit());
|
||||
|
||||
Document doc10 = getDoc("10");
|
||||
assertEquals("class2", doc10.get(CLASS));
|
||||
Document doc11 = getDoc("11");
|
||||
assertEquals("class1", doc11.get(CLASS));
|
||||
}
|
||||
|
||||
/**
|
||||
* Index some example documents with a class manually assigned.
|
||||
* This will be our trained model.
|
||||
*
|
||||
* @throws Exception If there is a low-level I/O error
|
||||
*/
|
||||
private void prepareTrainedIndex() throws Exception {
|
||||
//class1
|
||||
addDoc(adoc(ID, "1",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
addDoc(adoc(ID, "2",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
addDoc(adoc(ID, "3",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
addDoc(adoc(ID, "4",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"));
|
||||
//class2
|
||||
addDoc(adoc(ID, "5",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(adoc(ID, "6",
|
||||
TITLE, "word4 word4",
|
||||
CONTENT, "word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(adoc(ID, "7",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(adoc(ID, "8",
|
||||
TITLE, "word4",
|
||||
CONTENT, "word5 word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class2"));
|
||||
addDoc(commit());
|
||||
}
|
||||
|
||||
private Document getDoc(String id) throws IOException {
|
||||
try (SolrQueryRequest req = req()) {
|
||||
SolrIndexSearcher searcher = req.getSearcher();
|
||||
TermQuery query = new TermQuery(new Term(ID, id));
|
||||
TopDocs doc1 = searcher.search(query, 1);
|
||||
ScoreDoc scoreDoc = doc1.scoreDocs[0];
|
||||
return searcher.doc(scoreDoc.doc);
|
||||
}
|
||||
}
|
||||
|
||||
static void addDoc(String doc) throws Exception {
|
||||
Map<String, String[]> params = new HashMap<>();
|
||||
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
|
||||
params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
|
||||
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
|
||||
(SolrParams) mmparams) {
|
||||
};
|
||||
|
||||
UpdateRequestHandler handler = new UpdateRequestHandler();
|
||||
handler.init(null);
|
||||
ArrayList<ContentStream> streams = new ArrayList<>(2);
|
||||
streams.add(new ContentStreamBase.StringStream(doc));
|
||||
req.setContentStreams(streams);
|
||||
handler.handleRequestBody(req, new SolrQueryResponse());
|
||||
req.close();
|
||||
assertEquals(ClassificationUpdateProcessorFactory.Algorithm.KNN, classificationParams.getAlgorithm());
|
||||
assertEquals(1, classificationParams.getMinDf());
|
||||
assertEquals(1, classificationParams.getMinTf());
|
||||
assertEquals(10, classificationParams.getK());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,192 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.update.processor;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.solr.SolrTestCaseJ4;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.search.SolrIndexSearcher;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.hamcrest.core.Is.is;
|
||||
|
||||
/**
|
||||
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
|
||||
*/
|
||||
public class ClassificationUpdateProcessorIntegrationTest extends SolrTestCaseJ4 {
|
||||
/* field names are used in accordance with the solrconfig and schema supplied */
|
||||
private static final String ID = "id";
|
||||
private static final String TITLE = "title";
|
||||
private static final String CONTENT = "content";
|
||||
private static final String AUTHOR = "author";
|
||||
private static final String CLASS = "cat";
|
||||
|
||||
private static final String CHAIN = "classification";
|
||||
private static final String BROKEN_CHAIN_FILTER_QUERY = "classification-unsupported-filterQuery";
|
||||
|
||||
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
|
||||
private NamedList args = new NamedList<String>();
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
System.setProperty("enable.update.log", "false");
|
||||
initCore("solrconfig-classification.xml", "schema-classification.xml");
|
||||
}
|
||||
|
||||
@Override
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
clearIndex();
|
||||
assertU(commit());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_fullConfiguration_shouldAutoClassify() throws Exception {
|
||||
indexTrainingSet();
|
||||
// To be classified,we index documents without a class and verify the expected one is returned
|
||||
addDoc(adoc(ID, "22",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 ",
|
||||
AUTHOR, "Name1 Surname1"), CHAIN);
|
||||
addDoc(adoc(ID, "21",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname"), CHAIN);
|
||||
addDoc(commit());
|
||||
|
||||
Document doc22 = getDoc("22");
|
||||
assertThat(doc22.get(CLASS),is("class2"));
|
||||
Document doc21 = getDoc("21");
|
||||
assertThat(doc21.get(CLASS),is("class1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_unsupportedFilterQueryConfiguration_shouldThrowExceptionWithDetailedMessage() throws Exception {
|
||||
indexTrainingSet();
|
||||
try {
|
||||
addDoc(adoc(ID, "21",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 ",
|
||||
AUTHOR, "Name1 Surname1"), BROKEN_CHAIN_FILTER_QUERY);
|
||||
addDoc(adoc(ID, "22",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname"), BROKEN_CHAIN_FILTER_QUERY);
|
||||
addDoc(commit());
|
||||
} catch (SolrException e) {
|
||||
assertEquals("Classification UpdateProcessor Training Filter Query: 'not valid ( lucene query' is not supported", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Index some example documents with a class manually assigned.
|
||||
* This will be our trained model.
|
||||
*
|
||||
* @throws Exception If there is a low-level I/O error
|
||||
*/
|
||||
private void indexTrainingSet() throws Exception {
|
||||
//class1
|
||||
addDoc(adoc(ID, "1",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"), CHAIN);
|
||||
addDoc(adoc(ID, "2",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"), CHAIN);
|
||||
addDoc(adoc(ID, "3",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"), CHAIN);
|
||||
addDoc(adoc(ID, "4",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class1"), CHAIN);
|
||||
//class2
|
||||
addDoc(adoc(ID, "5",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class2"), CHAIN);
|
||||
addDoc(adoc(ID, "6",
|
||||
TITLE, "word4 word4",
|
||||
CONTENT, "word5",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class2"), CHAIN);
|
||||
addDoc(adoc(ID, "7",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 word5",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class2"), CHAIN);
|
||||
addDoc(adoc(ID, "8",
|
||||
TITLE, "word4",
|
||||
CONTENT, "word5 word5 word5 word5",
|
||||
AUTHOR, "Name Surname",
|
||||
CLASS, "class2"), CHAIN);
|
||||
//class3
|
||||
addDoc(adoc(ID, "9",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class3"), CHAIN);
|
||||
addDoc(adoc(ID, "10",
|
||||
TITLE, "word4 word4",
|
||||
CONTENT, "word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class3"), CHAIN);
|
||||
addDoc(adoc(ID, "11",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class3"), CHAIN);
|
||||
addDoc(adoc(ID, "12",
|
||||
TITLE, "word4",
|
||||
CONTENT, "word5 word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
CLASS, "class3"), CHAIN);
|
||||
addDoc(commit());
|
||||
}
|
||||
|
||||
private Document getDoc(String id) throws IOException {
|
||||
try (SolrQueryRequest req = req()) {
|
||||
SolrIndexSearcher searcher = req.getSearcher();
|
||||
TermQuery query = new TermQuery(new Term(ID, id));
|
||||
TopDocs doc1 = searcher.search(query, 1);
|
||||
ScoreDoc scoreDoc = doc1.scoreDocs[0];
|
||||
return searcher.doc(scoreDoc.doc);
|
||||
}
|
||||
}
|
||||
|
||||
private void addDoc(String doc) throws Exception {
|
||||
addDoc(doc, CHAIN);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,507 @@
|
|||
package org.apache.solr.update.processor;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.analysis.MockTokenizer;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.solr.SolrTestCaseJ4;
|
||||
import org.apache.solr.common.SolrInputDocument;
|
||||
import org.apache.solr.update.AddUpdateCommand;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.hamcrest.core.Is.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Tests for {@link ClassificationUpdateProcessor}
|
||||
*/
|
||||
public class ClassificationUpdateProcessorTest extends SolrTestCaseJ4 {
|
||||
/* field names are used in accordance with the solrconfig and schema supplied */
|
||||
private static final String ID = "id";
|
||||
private static final String TITLE = "title";
|
||||
private static final String CONTENT = "content";
|
||||
private static final String AUTHOR = "author";
|
||||
private static final String TRAINING_CLASS = "cat";
|
||||
private static final String PREDICTED_CLASS = "predicted";
|
||||
public static final String KNN = "knn";
|
||||
|
||||
protected Directory directory;
|
||||
protected IndexReader reader;
|
||||
protected IndexSearcher searcher;
|
||||
protected Analyzer analyzer = new MockAnalyzer(random(), MockTokenizer.WHITESPACE, false);
|
||||
private ClassificationUpdateProcessor updateProcessorToTest;
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
System.setProperty("enable.update.log", "false");
|
||||
initCore("solrconfig-classification.xml", "schema-classification.xml");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void tearDown() throws Exception {
|
||||
reader.close();
|
||||
directory.close();
|
||||
analyzer.close();
|
||||
super.tearDown();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void classificationMonoClass_predictedClassFieldSet_shouldAssignClassInPredictedClassField() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMonoClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
|
||||
params.setPredictedClassField(PREDICTED_CLASS);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
assertThat(unseenDocument1.getFieldValue(PREDICTED_CLASS),is("class1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void knnMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMonoClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void knnMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMonoClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
|
||||
params.setInputFieldNames(new String[]{TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"});
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bayesMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMonoClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void knnMonoClass_contextQueryFiltered_shouldAssignCorrectClass() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMonoClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "a");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
|
||||
Query class3DocsChunk=new TermQuery(new Term(TITLE,"word6"));
|
||||
params.setTrainingFilterQuery(class3DocsChunk);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class3"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bayesMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMonoClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
|
||||
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void knnClassification_maxOutputClassesGreaterThanAvailable_shouldAssignCorrectClass() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMultiClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
|
||||
params.setMaxPredictedClasses(100);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
|
||||
assertThat(assignedClasses.get(0),is("class2"));
|
||||
assertThat(assignedClasses.get(1),is("class1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void knnMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMultiClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
|
||||
params.setMaxPredictedClasses(2);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
|
||||
assertThat(assignedClasses.size(),is(2));
|
||||
assertThat(assignedClasses.get(0),is("class2"));
|
||||
assertThat(assignedClasses.get(1),is("class1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bayesMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMultiClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
|
||||
params.setMaxPredictedClasses(2);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
|
||||
assertThat(assignedClasses.size(),is(2));
|
||||
assertThat(assignedClasses.get(0),is("class2"));
|
||||
assertThat(assignedClasses.get(1),is("class1"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void knnMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMultiClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
|
||||
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
|
||||
params.setMaxPredictedClasses(2);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
|
||||
assertThat(assignedClasses.size(),is(2));
|
||||
assertThat(assignedClasses.get(0),is("class4"));
|
||||
assertThat(assignedClasses.get(1),is("class6"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void bayesMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
|
||||
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
|
||||
prepareTrainedIndexMultiClass();
|
||||
|
||||
AddUpdateCommand update=new AddUpdateCommand(req());
|
||||
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word2 word2 ",
|
||||
AUTHOR, "unseenAuthor");
|
||||
update.solrDoc=unseenDocument1;
|
||||
|
||||
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
|
||||
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
|
||||
params.setMaxPredictedClasses(2);
|
||||
|
||||
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
|
||||
|
||||
updateProcessorToTest.processAdd(update);
|
||||
|
||||
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
|
||||
assertThat(assignedClasses.size(),is(2));
|
||||
assertThat(assignedClasses.get(0),is("class4"));
|
||||
assertThat(assignedClasses.get(1),is("class6"));
|
||||
}
|
||||
|
||||
private ClassificationUpdateProcessorParams initParams(ClassificationUpdateProcessorFactory.Algorithm classificationAlgorithm) {
|
||||
ClassificationUpdateProcessorParams params= new ClassificationUpdateProcessorParams();
|
||||
params.setInputFieldNames(new String[]{TITLE,CONTENT,AUTHOR});
|
||||
params.setTrainingClassField(TRAINING_CLASS);
|
||||
params.setPredictedClassField(TRAINING_CLASS);
|
||||
params.setMinTf(1);
|
||||
params.setMinDf(1);
|
||||
params.setK(5);
|
||||
params.setAlgorithm(classificationAlgorithm);
|
||||
params.setMaxPredictedClasses(1);
|
||||
return params;
|
||||
}
|
||||
|
||||
/**
|
||||
* Index some example documents with a class manually assigned.
|
||||
* This will be our trained model.
|
||||
*
|
||||
* @throws Exception If there is a low-level I/O error
|
||||
*/
|
||||
private void prepareTrainedIndexMonoClass() throws Exception {
|
||||
directory = newDirectory();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
|
||||
|
||||
//class1
|
||||
addDoc(writer, buildLuceneDocument(ID, "1",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class1"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "2",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class1"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "3",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class1"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "4",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class1"));
|
||||
//class2
|
||||
addDoc(writer, buildLuceneDocument(ID, "5",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5",
|
||||
AUTHOR, "c",
|
||||
TRAINING_CLASS, "class2"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "6",
|
||||
TITLE, "word4 word4",
|
||||
CONTENT, "word5",
|
||||
AUTHOR, "c",
|
||||
TRAINING_CLASS, "class2"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "7",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 word5",
|
||||
AUTHOR, "c",
|
||||
TRAINING_CLASS, "class2"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "8",
|
||||
TITLE, "word4",
|
||||
CONTENT, "word5 word5 word5 word5",
|
||||
AUTHOR, "c",
|
||||
TRAINING_CLASS, "class2"));
|
||||
//class3
|
||||
addDoc(writer, buildLuceneDocument(ID, "9",
|
||||
TITLE, "word6",
|
||||
CONTENT, "word7",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class3"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "10",
|
||||
TITLE, "word6",
|
||||
CONTENT, "word7",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class3"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "11",
|
||||
TITLE, "word6",
|
||||
CONTENT, "word7",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class3"));
|
||||
addDoc(writer, buildLuceneDocument(ID, "12",
|
||||
TITLE, "word6",
|
||||
CONTENT, "word7",
|
||||
AUTHOR, "a",
|
||||
TRAINING_CLASS, "class3"));
|
||||
|
||||
reader = writer.getReader();
|
||||
writer.close();
|
||||
searcher = newSearcher(reader);
|
||||
}
|
||||
|
||||
private void prepareTrainedIndexMultiClass() throws Exception {
|
||||
directory = newDirectory();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
|
||||
|
||||
//class1
|
||||
addDoc(writer, buildLuceneDocument(ID, "1",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
TRAINING_CLASS, "class1",
|
||||
TRAINING_CLASS, "class2"
|
||||
));
|
||||
addDoc(writer, buildLuceneDocument(ID, "2",
|
||||
TITLE, "word1 word1",
|
||||
CONTENT, "word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
TRAINING_CLASS, "class3",
|
||||
TRAINING_CLASS, "class2"
|
||||
));
|
||||
addDoc(writer, buildLuceneDocument(ID, "3",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2",
|
||||
AUTHOR, "Name Surname",
|
||||
TRAINING_CLASS, "class1",
|
||||
TRAINING_CLASS, "class2"
|
||||
));
|
||||
addDoc(writer, buildLuceneDocument(ID, "4",
|
||||
TITLE, "word1 word1 word1",
|
||||
CONTENT, "word2 word2 word2",
|
||||
AUTHOR, "Name Surname",
|
||||
TRAINING_CLASS, "class1",
|
||||
TRAINING_CLASS, "class2"
|
||||
));
|
||||
//class2
|
||||
addDoc(writer, buildLuceneDocument(ID, "5",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
TRAINING_CLASS, "class6",
|
||||
TRAINING_CLASS, "class4"
|
||||
));
|
||||
addDoc(writer, buildLuceneDocument(ID, "6",
|
||||
TITLE, "word4 word4",
|
||||
CONTENT, "word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
TRAINING_CLASS, "class5",
|
||||
TRAINING_CLASS, "class4"
|
||||
));
|
||||
addDoc(writer, buildLuceneDocument(ID, "7",
|
||||
TITLE, "word4 word4 word4",
|
||||
CONTENT, "word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
TRAINING_CLASS, "class6",
|
||||
TRAINING_CLASS, "class4"
|
||||
));
|
||||
addDoc(writer, buildLuceneDocument(ID, "8",
|
||||
TITLE, "word4",
|
||||
CONTENT, "word5 word5 word5 word5",
|
||||
AUTHOR, "Name1 Surname1",
|
||||
TRAINING_CLASS, "class6",
|
||||
TRAINING_CLASS, "class4"
|
||||
));
|
||||
|
||||
reader = writer.getReader();
|
||||
writer.close();
|
||||
searcher = newSearcher(reader);
|
||||
}
|
||||
|
||||
public static Document buildLuceneDocument(Object... fieldsAndValues) {
|
||||
Document luceneDoc = new Document();
|
||||
for (int i=0; i<fieldsAndValues.length; i+=2) {
|
||||
luceneDoc.add(newTextField((String)fieldsAndValues[i], (String)fieldsAndValues[i+1], Field.Store.YES));
|
||||
}
|
||||
return luceneDoc;
|
||||
}
|
||||
|
||||
private int addDoc(RandomIndexWriter writer, Document doc) throws IOException {
|
||||
writer.addDocument(doc);
|
||||
return writer.numDocs() - 1;
|
||||
}
|
||||
}
|
|
@ -16,31 +16,28 @@
|
|||
*/
|
||||
package org.apache.solr.update.processor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.util.Constants;
|
||||
import org.apache.solr.SolrTestCaseJ4;
|
||||
import org.apache.solr.client.solrj.impl.BinaryRequestWriter;
|
||||
import org.apache.solr.client.solrj.request.UpdateRequest;
|
||||
import org.apache.solr.common.SolrInputDocument;
|
||||
import org.apache.solr.common.params.MultiMapSolrParams;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.params.UpdateParams;
|
||||
import org.apache.solr.common.util.ContentStream;
|
||||
import org.apache.solr.common.util.ContentStreamBase;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.core.SolrCore;
|
||||
import org.apache.solr.handler.UpdateRequestHandler;
|
||||
import org.apache.solr.request.LocalSolrQueryRequest;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.request.SolrQueryRequestBase;
|
||||
import org.apache.solr.response.SolrQueryResponse;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
|
@ -359,21 +356,4 @@ public class SignatureUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
|
|||
private void addDoc(String doc) throws Exception {
|
||||
addDoc(doc, chain);
|
||||
}
|
||||
|
||||
static void addDoc(String doc, String chain) throws Exception {
|
||||
Map<String, String[]> params = new HashMap<>();
|
||||
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
|
||||
params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
|
||||
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
|
||||
(SolrParams) mmparams) {
|
||||
};
|
||||
|
||||
UpdateRequestHandler handler = new UpdateRequestHandler();
|
||||
handler.init(null);
|
||||
ArrayList<ContentStream> streams = new ArrayList<>(2);
|
||||
streams.add(new ContentStreamBase.StringStream(doc));
|
||||
req.setContentStreams(streams);
|
||||
handler.handleRequestBody(req, new SolrQueryResponse());
|
||||
req.close();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,8 +25,6 @@ import org.junit.Test;
|
|||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.apache.solr.update.processor.SignatureUpdateProcessorFactoryTest.addDoc;
|
||||
|
||||
public class TestPartialUpdateDeduplication extends SolrTestCaseJ4 {
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
|
|
|
@ -83,7 +83,11 @@ import org.apache.solr.common.SolrInputDocument;
|
|||
import org.apache.solr.common.SolrInputField;
|
||||
import org.apache.solr.common.params.CommonParams;
|
||||
import org.apache.solr.common.params.ModifiableSolrParams;
|
||||
import org.apache.solr.common.params.MultiMapSolrParams;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.params.UpdateParams;
|
||||
import org.apache.solr.common.util.ContentStream;
|
||||
import org.apache.solr.common.util.ContentStreamBase;
|
||||
import org.apache.solr.common.util.ObjectReleaseTracker;
|
||||
import org.apache.solr.common.util.XML;
|
||||
import org.apache.solr.core.CoreContainer;
|
||||
|
@ -96,7 +100,9 @@ import org.apache.solr.core.SolrXmlConfig;
|
|||
import org.apache.solr.handler.UpdateRequestHandler;
|
||||
import org.apache.solr.request.LocalSolrQueryRequest;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.request.SolrQueryRequestBase;
|
||||
import org.apache.solr.request.SolrRequestHandler;
|
||||
import org.apache.solr.response.SolrQueryResponse;
|
||||
import org.apache.solr.schema.IndexSchema;
|
||||
import org.apache.solr.schema.SchemaField;
|
||||
import org.apache.solr.search.SolrIndexSearcher;
|
||||
|
@ -1009,6 +1015,22 @@ public abstract class SolrTestCaseJ4 extends LuceneTestCase {
|
|||
return out.toString();
|
||||
}
|
||||
|
||||
public static void addDoc(String doc, String updateRequestProcessorChain) throws Exception {
|
||||
Map<String, String[]> params = new HashMap<>();
|
||||
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
|
||||
params.put(UpdateParams.UPDATE_CHAIN, new String[]{updateRequestProcessorChain});
|
||||
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
|
||||
(SolrParams) mmparams) {
|
||||
};
|
||||
|
||||
UpdateRequestHandler handler = new UpdateRequestHandler();
|
||||
handler.init(null);
|
||||
ArrayList<ContentStream> streams = new ArrayList<>(2);
|
||||
streams.add(new ContentStreamBase.StringStream(doc));
|
||||
req.setContentStreams(streams);
|
||||
handler.handleRequestBody(req, new SolrQueryResponse());
|
||||
req.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates an <add><doc>... XML String with options
|
||||
|
|
Loading…
Reference in New Issue