SOLR-8871 - various improvements to ClassificationURP

This commit is contained in:
Tommaso Teofili 2016-11-24 23:43:57 +01:00
parent e9e4715dd2
commit 5ad741eef8
11 changed files with 1014 additions and 357 deletions

View File

@ -16,22 +16,12 @@
*/
package org.apache.solr.uima.processor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.util.LuceneTestCase.Slow;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.uima.processor.SolrUIMAConfiguration.MapField;
import org.apache.solr.update.processor.UpdateRequestProcessor;
import org.apache.solr.update.processor.UpdateRequestProcessorChain;
@ -188,19 +178,4 @@ public class UIMAUpdateRequestProcessorTest extends SolrTestCaseJ4 {
}
}
private void addDoc(String chain, String doc) throws Exception {
Map<String, String[]> params = new HashMap<>();
params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(), (SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
}
}

View File

@ -19,6 +19,7 @@ package org.apache.solr.update.processor;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
@ -33,6 +34,7 @@ import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.update.AddUpdateCommand;
import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm;
/**
* This Class is a Request Update Processor to classify the document in input and add a field
@ -42,43 +44,54 @@ import org.apache.solr.update.AddUpdateCommand;
class ClassificationUpdateProcessor
extends UpdateRequestProcessor {
private String classFieldName; // the field to index the assigned class
private final String trainingClassField;
private final String predictedClassField;
private final int maxOutputClasses;
private DocumentClassifier<BytesRef> classifier;
/**
* Sole constructor
*
* @param inputFieldNames fields to be used as classifier's inputs
* @param classFieldName field to be used as classifier's output
* @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
* @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
* @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
* @param algorithm the name of the classifier to use
* @param classificationParams classification advanced params
* @param next next update processor in the chain
* @param indexReader index reader
* @param schema schema
*/
public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
public ClassificationUpdateProcessor(ClassificationUpdateProcessorParams classificationParams, UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
super(next);
this.classFieldName = classFieldName;
Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
this.trainingClassField = classificationParams.getTrainingClassField();
this.predictedClassField = classificationParams.getPredictedClassField();
this.maxOutputClasses = classificationParams.getMaxPredictedClasses();
String[] inputFieldNamesWithBoost = classificationParams.getInputFieldNames();
Algorithm classificationAlgorithm = classificationParams.getAlgorithm();
Map<String, Analyzer> field2analyzer = new HashMap<>();
String[] inputFieldNames = this.removeBoost(inputFieldNamesWithBoost);
for (String fieldName : inputFieldNames) {
SchemaField fieldFromSolrSchema = schema.getField(fieldName);
Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
field2analyzer.put(fieldName, indexAnalyzer);
}
switch (algorithm) {
case "knn":
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
switch (classificationAlgorithm) {
case KNN:
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, classificationParams.getTrainingFilterQuery(), classificationParams.getK(), classificationParams.getMinDf(), classificationParams.getMinTf(), trainingClassField, field2analyzer, inputFieldNamesWithBoost);
break;
case "bayes":
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
case BAYES:
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, trainingClassField, field2analyzer, inputFieldNamesWithBoost);
break;
}
}
private String[] removeBoost(String[] inputFieldNamesWithBoost) {
String[] inputFieldNames = new String[inputFieldNamesWithBoost.length];
for (int i = 0; i < inputFieldNamesWithBoost.length; i++) {
String singleFieldNameWithBoost = inputFieldNamesWithBoost[i];
String[] fieldName2boost = singleFieldNameWithBoost.split("\\^");
inputFieldNames[i] = fieldName2boost[0];
}
return inputFieldNames;
}
/**
* @param cmd the update command in input containing the Document to classify
* @throws IOException If there is a low-level I/O error
@ -89,12 +102,14 @@ class ClassificationUpdateProcessor
SolrInputDocument doc = cmd.getSolrInputDocument();
Document luceneDocument = cmd.getLuceneDocument();
String assignedClass;
Object documentClass = doc.getFieldValue(classFieldName);
Object documentClass = doc.getFieldValue(trainingClassField);
if (documentClass == null) {
ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
if (classificationResult != null) {
assignedClass = classificationResult.getAssignedClass().utf8ToString();
doc.addField(classFieldName, assignedClass);
List<ClassificationResult<BytesRef>> assignedClassifications = classifier.getClasses(luceneDocument, maxOutputClasses);
if (assignedClassifications != null) {
for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) {
assignedClass = singleClassification.getAssignedClass().utf8ToString();
doc.addField(predictedClassField, assignedClass);
}
}
}
super.processAdd(cmd);

View File

@ -18,12 +18,17 @@
package org.apache.solr.update.processor;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.search.LuceneQParser;
import org.apache.solr.search.SyntaxError;
import static org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm.KNN;
/**
* This class implements an UpdateProcessorFactory for the Classification Update Processor.
@ -33,49 +38,67 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
// Update Processor Config params
private static final String INPUT_FIELDS_PARAM = "inputFields";
private static final String CLASS_FIELD_PARAM = "classField";
private static final String TRAINING_CLASS_FIELD_PARAM = "classField";
private static final String PREDICTED_CLASS_FIELD_PARAM = "predictedClassField";
private static final String MAX_CLASSES_TO_ASSIGN_PARAM = "predictedClass.maxCount";
private static final String ALGORITHM_PARAM = "algorithm";
private static final String KNN_MIN_TF_PARAM = "knn.minTf";
private static final String KNN_MIN_DF_PARAM = "knn.minDf";
private static final String KNN_K_PARAM = "knn.k";
private static final String KNN_FILTER_QUERY = "knn.filterQuery";
public enum Algorithm {KNN, BAYES}
//Update Processor Defaults
private static final int DEFAULT_MAX_CLASSES_TO_ASSIGN = 1;
private static final int DEFAULT_MIN_TF = 1;
private static final int DEFAULT_MIN_DF = 1;
private static final int DEFAULT_K = 10;
private static final String DEFAULT_ALGORITHM = "knn";
private static final Algorithm DEFAULT_ALGORITHM = KNN;
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
private String classFieldName; // the field containing the class for the Document
private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
private int minTf; // knn specific - the minimum Term Frequency for considering a term
private int minDf; // knn specific - the minimum Document Frequency for considering a term
private int k; // knn specific - thw window of top results to evaluate, when assigning the class
private SolrParams params;
private ClassificationUpdateProcessorParams classificationParams;
@Override
public void init(final NamedList args) {
if (args != null) {
SolrParams params = SolrParams.toSolrParams(args);
params = SolrParams.toSolrParams(args);
classificationParams = new ClassificationUpdateProcessorParams();
String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
inputFieldNames = fieldNames.split("\\,");
classificationParams.setInputFieldNames(fieldNames.split("\\,"));
classFieldName = params.get(CLASS_FIELD_PARAM);
checkNotNull(CLASS_FIELD_PARAM, classFieldName);
String trainingClassField = (params.get(TRAINING_CLASS_FIELD_PARAM));
checkNotNull(TRAINING_CLASS_FIELD_PARAM, trainingClassField);
classificationParams.setTrainingClassField(trainingClassField);
algorithm = params.get(ALGORITHM_PARAM);
if (algorithm == null)
algorithm = DEFAULT_ALGORITHM;
String predictedClassField = (params.get(PREDICTED_CLASS_FIELD_PARAM));
if (predictedClassField == null || predictedClassField.isEmpty()) {
predictedClassField = trainingClassField;
}
classificationParams.setPredictedClassField(predictedClassField);
minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
classificationParams.setMaxPredictedClasses(getIntParam(params, MAX_CLASSES_TO_ASSIGN_PARAM, DEFAULT_MAX_CLASSES_TO_ASSIGN));
String algorithmString = params.get(ALGORITHM_PARAM);
Algorithm classificationAlgorithm;
try {
if (algorithmString == null || Algorithm.valueOf(algorithmString.toUpperCase()) == null) {
classificationAlgorithm = DEFAULT_ALGORITHM;
} else {
classificationAlgorithm = Algorithm.valueOf(algorithmString.toUpperCase());
}
} catch (IllegalArgumentException e) {
throw new SolrException
(SolrException.ErrorCode.SERVER_ERROR,
"Classification UpdateProcessor Algorithm: '" + algorithmString + "' not supported");
}
classificationParams.setAlgorithm(classificationAlgorithm);
classificationParams.setMinTf(getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF));
classificationParams.setMinDf(getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF));
classificationParams.setK(getIntParam(params, KNN_K_PARAM, DEFAULT_K));
}
}
@ -108,116 +131,34 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
@Override
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
String trainingFilterQueryString = (params.get(KNN_FILTER_QUERY));
try {
if (trainingFilterQueryString != null && !trainingFilterQueryString.isEmpty()) {
Query trainingFilterQuery = this.parseFilterQuery(trainingFilterQueryString, params, req);
classificationParams.setTrainingFilterQuery(trainingFilterQuery);
}
} catch (SyntaxError | RuntimeException syntaxError) {
throw new SolrException
(SolrException.ErrorCode.SERVER_ERROR,
"Classification UpdateProcessor Training Filter Query: '" + trainingFilterQueryString + "' is not supported", syntaxError);
}
IndexSchema schema = req.getSchema();
IndexReader indexReader = req.getSearcher().getIndexReader();
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
return new ClassificationUpdateProcessor(classificationParams, next, indexReader, schema);
}
/**
* get field names used as classifier's inputs
*
* @return the input field names
*/
public String[] getInputFieldNames() {
return inputFieldNames;
private Query parseFilterQuery(String trainingFilterQueryString, SolrParams params, SolrQueryRequest req) throws SyntaxError {
LuceneQParser parser = new LuceneQParser(trainingFilterQueryString, null, params, req);
return parser.parse();
}
/**
* set field names used as classifier's inputs
*
* @param inputFieldNames the input field names
*/
public void setInputFieldNames(String[] inputFieldNames) {
this.inputFieldNames = inputFieldNames;
public ClassificationUpdateProcessorParams getClassificationParams() {
return classificationParams;
}
/**
* get field names used as classifier's output
*
* @return the output field name
*/
public String getClassFieldName() {
return classFieldName;
}
/**
* set field names used as classifier's output
*
* @param classFieldName the output field name
*/
public void setClassFieldName(String classFieldName) {
this.classFieldName = classFieldName;
}
/**
* get the name of the classifier algorithm used
*
* @return the classifier algorithm used
*/
public String getAlgorithm() {
return algorithm;
}
/**
* set the name of the classifier algorithm used
*
* @param algorithm the classifier algorithm used
*/
public void setAlgorithm(String algorithm) {
this.algorithm = algorithm;
}
/**
* get the min term frequency value to be used in case algorithm is {@code "knn"}
*
* @return the min term frequency
*/
public int getMinTf() {
return minTf;
}
/**
* set the min term frequency value to be used in case algorithm is {@code "knn"}
*
* @param minTf the min term frequency
*/
public void setMinTf(int minTf) {
this.minTf = minTf;
}
/**
* get the min document frequency value to be used in case algorithm is {@code "knn"}
*
* @return the min document frequency
*/
public int getMinDf() {
return minDf;
}
/**
* set the min document frequency value to be used in case algorithm is {@code "knn"}
*
* @param minDf the min document frequency
*/
public void setMinDf(int minDf) {
this.minDf = minDf;
}
/**
* get the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
*
* @return the no. of neighbors to analyze
*/
public int getK() {
return k;
}
/**
* set the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
*
* @param k the no. of neighbors to analyze
*/
public void setK(int k) {
this.k = k;
public void setClassificationParams(ClassificationUpdateProcessorParams classificationParams) {
this.classificationParams = classificationParams;
}
}

View File

@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.update.processor;
import org.apache.lucene.search.Query;
public class ClassificationUpdateProcessorParams {
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
private Query trainingFilterQuery; // a filter query to reduce the training set to a subset
private String trainingClassField; // the field containing the class for the Document
private String predictedClassField; // the field that will contain the predicted class
private int maxPredictedClasses; // the max number of classes to assign
private ClassificationUpdateProcessorFactory.Algorithm algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
private int minTf; // knn specific - the minimum Term Frequency for considering a term
private int minDf; // knn specific - the minimum Document Frequency for considering a term
private int k; // knn specific - thw window of top results to evaluate, when assigning the class
public String[] getInputFieldNames() {
return inputFieldNames;
}
public void setInputFieldNames(String[] inputFieldNames) {
this.inputFieldNames = inputFieldNames;
}
public Query getTrainingFilterQuery() {
return trainingFilterQuery;
}
public void setTrainingFilterQuery(Query trainingFilterQuery) {
this.trainingFilterQuery = trainingFilterQuery;
}
public String getTrainingClassField() {
return trainingClassField;
}
public void setTrainingClassField(String trainingClassField) {
this.trainingClassField = trainingClassField;
}
public String getPredictedClassField() {
return predictedClassField;
}
public void setPredictedClassField(String predictedClassField) {
this.predictedClassField = predictedClassField;
}
public int getMaxPredictedClasses() {
return maxPredictedClasses;
}
public void setMaxPredictedClasses(int maxPredictedClasses) {
this.maxPredictedClasses = maxPredictedClasses;
}
public ClassificationUpdateProcessorFactory.Algorithm getAlgorithm() {
return algorithm;
}
public void setAlgorithm(ClassificationUpdateProcessorFactory.Algorithm algorithm) {
this.algorithm = algorithm;
}
public int getMinTf() {
return minTf;
}
public void setMinTf(int minTf) {
this.minTf = minTf;
}
public int getMinDf() {
return minDf;
}
public void setMinDf(int minDf) {
this.minDf = minDf;
}
public int getK() {
return k;
}
public void setK(int k) {
this.k = k;
}
}

View File

@ -47,6 +47,21 @@
<str name="knn.minTf">1</str>
<str name="knn.minDf">1</str>
<str name="knn.k">5</str>
<str name="knn.filterQuery">cat:(class1 OR class2)</str>
</processor>
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>
<updateRequestProcessorChain name="classification-unsupported-filterQuery">
<processor class="solr.ClassificationUpdateProcessorFactory">
<str name="inputFields">title,content,author</str>
<str name="classField">cat</str>
<!-- Knn algorithm specific-->
<str name="algorithm">knn</str>
<str name="knn.minTf">1</str>
<str name="knn.minDf">1</str>
<str name="knn.k">5</str>
<str name="knn.filterQuery">not valid ( lucene query</str>
</processor>
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>

View File

@ -14,71 +14,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.update.processor;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.search.SolrIndexSearcher;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
* Tests for {@link ClassificationUpdateProcessorFactory}
*/
public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
// field names are used in accordance with the solrconfig and schema supplied
private static final String ID = "id";
private static final String TITLE = "title";
private static final String CONTENT = "content";
private static final String AUTHOR = "author";
private static final String CLASS = "cat";
private static final String CHAIN = "classification";
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
private NamedList args = new NamedList<String>();
@BeforeClass
public static void beforeClass() throws Exception {
System.setProperty("enable.update.log", "false");
initCore("solrconfig-classification.xml", "schema-classification.xml");
}
@Override
@Before
public void setUp() throws Exception {
super.setUp();
clearIndex();
assertU(commit());
}
@Before
public void initArgs() {
args.add("inputFields", "inputField1,inputField2");
args.add("classField", "classField1");
args.add("predictedClassField", "classFieldX");
args.add("algorithm", "bayes");
args.add("knn.k", "9");
args.add("knn.minDf", "8");
@ -86,22 +46,23 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
public void testFullInit() {
public void init_fullArgs_shouldInitFullClassificationParams() {
cFactoryToTest.init(args);
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
String[] inputFieldNames = classificationParams.getInputFieldNames();
assertEquals("inputField1", inputFieldNames[0]);
assertEquals("inputField2", inputFieldNames[1]);
assertEquals("classField1", cFactoryToTest.getClassFieldName());
assertEquals("bayes", cFactoryToTest.getAlgorithm());
assertEquals(8, cFactoryToTest.getMinDf());
assertEquals(10, cFactoryToTest.getMinTf());
assertEquals(9, cFactoryToTest.getK());
assertEquals("classField1", classificationParams.getTrainingClassField());
assertEquals("classFieldX", classificationParams.getPredictedClassField());
assertEquals(ClassificationUpdateProcessorFactory.Algorithm.BAYES, classificationParams.getAlgorithm());
assertEquals(8, classificationParams.getMinDf());
assertEquals(10, classificationParams.getMinTf());
assertEquals(9, classificationParams.getK());
}
@Test
public void testInitEmptyInputField() {
public void init_emptyInputFields_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("inputFields");
try {
cFactoryToTest.init(args);
@ -111,7 +72,7 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
public void testInitEmptyClassField() {
public void init_emptyClassField_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("classField");
try {
cFactoryToTest.init(args);
@ -121,114 +82,53 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
public void testDefaults() {
public void init_emptyPredictedClassField_shouldDefaultToTrainingClassField() {
args.removeAll("predictedClassField");
cFactoryToTest.init(args);
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
assertThat(classificationParams.getPredictedClassField(), is("classField1"));
}
@Test
public void init_unsupportedAlgorithm_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("algorithm");
args.add("algorithm", "unsupported");
try {
cFactoryToTest.init(args);
} catch (SolrException e) {
assertEquals("Classification UpdateProcessor Algorithm: 'unsupported' not supported", e.getMessage());
}
}
@Test
public void init_unsupportedFilterQuery_shouldThrowExceptionWithDetailedMessage() {
UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
SolrQueryRequest mockRequest = mock(SolrQueryRequest.class);
SolrQueryResponse mockResponse = mock(SolrQueryResponse.class);
args.add("knn.filterQuery", "not supported query");
try {
cFactoryToTest.init(args);
/* parsing failure happens because of the mocks, fine enough to check a proper exception propagation */
cFactoryToTest.getInstance(mockRequest, mockResponse, mockProcessor);
} catch (SolrException e) {
assertEquals("Classification UpdateProcessor Training Filter Query: 'not supported query' is not supported", e.getMessage());
}
}
@Test
public void init_emptyArgs_shouldDefaultClassificationParams() {
args.removeAll("algorithm");
args.removeAll("knn.k");
args.removeAll("knn.minDf");
args.removeAll("knn.minTf");
cFactoryToTest.init(args);
assertEquals("knn", cFactoryToTest.getAlgorithm());
assertEquals(1, cFactoryToTest.getMinDf());
assertEquals(1, cFactoryToTest.getMinTf());
assertEquals(10, cFactoryToTest.getK());
}
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
@Test
public void testBasicClassification() throws Exception {
prepareTrainedIndex();
// To be classified,we index documents without a class and verify the expected one is returned
addDoc(adoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 ",
AUTHOR, "Name1 Surname1"));
addDoc(adoc(ID, "11",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname"));
addDoc(commit());
Document doc10 = getDoc("10");
assertEquals("class2", doc10.get(CLASS));
Document doc11 = getDoc("11");
assertEquals("class1", doc11.get(CLASS));
}
/**
* Index some example documents with a class manually assigned.
* This will be our trained model.
*
* @throws Exception If there is a low-level I/O error
*/
private void prepareTrainedIndex() throws Exception {
//class1
addDoc(adoc(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
addDoc(adoc(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
addDoc(adoc(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
addDoc(adoc(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
//class2
addDoc(adoc(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(adoc(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(adoc(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(adoc(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(commit());
}
private Document getDoc(String id) throws IOException {
try (SolrQueryRequest req = req()) {
SolrIndexSearcher searcher = req.getSearcher();
TermQuery query = new TermQuery(new Term(ID, id));
TopDocs doc1 = searcher.search(query, 1);
ScoreDoc scoreDoc = doc1.scoreDocs[0];
return searcher.doc(scoreDoc.doc);
}
}
static void addDoc(String doc) throws Exception {
Map<String, String[]> params = new HashMap<>();
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
(SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
req.close();
assertEquals(ClassificationUpdateProcessorFactory.Algorithm.KNN, classificationParams.getAlgorithm());
assertEquals(1, classificationParams.getMinDf());
assertEquals(1, classificationParams.getMinTf());
assertEquals(10, classificationParams.getK());
}
}

View File

@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.update.processor;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.search.SolrIndexSearcher;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.hamcrest.core.Is.is;
/**
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
*/
public class ClassificationUpdateProcessorIntegrationTest extends SolrTestCaseJ4 {
/* field names are used in accordance with the solrconfig and schema supplied */
private static final String ID = "id";
private static final String TITLE = "title";
private static final String CONTENT = "content";
private static final String AUTHOR = "author";
private static final String CLASS = "cat";
private static final String CHAIN = "classification";
private static final String BROKEN_CHAIN_FILTER_QUERY = "classification-unsupported-filterQuery";
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
private NamedList args = new NamedList<String>();
@BeforeClass
public static void beforeClass() throws Exception {
System.setProperty("enable.update.log", "false");
initCore("solrconfig-classification.xml", "schema-classification.xml");
}
@Override
@Before
public void setUp() throws Exception {
super.setUp();
clearIndex();
assertU(commit());
}
@Test
public void classify_fullConfiguration_shouldAutoClassify() throws Exception {
indexTrainingSet();
// To be classified,we index documents without a class and verify the expected one is returned
addDoc(adoc(ID, "22",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 ",
AUTHOR, "Name1 Surname1"), CHAIN);
addDoc(adoc(ID, "21",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname"), CHAIN);
addDoc(commit());
Document doc22 = getDoc("22");
assertThat(doc22.get(CLASS),is("class2"));
Document doc21 = getDoc("21");
assertThat(doc21.get(CLASS),is("class1"));
}
@Test
public void classify_unsupportedFilterQueryConfiguration_shouldThrowExceptionWithDetailedMessage() throws Exception {
indexTrainingSet();
try {
addDoc(adoc(ID, "21",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 ",
AUTHOR, "Name1 Surname1"), BROKEN_CHAIN_FILTER_QUERY);
addDoc(adoc(ID, "22",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname"), BROKEN_CHAIN_FILTER_QUERY);
addDoc(commit());
} catch (SolrException e) {
assertEquals("Classification UpdateProcessor Training Filter Query: 'not valid ( lucene query' is not supported", e.getMessage());
}
}
/**
* Index some example documents with a class manually assigned.
* This will be our trained model.
*
* @throws Exception If there is a low-level I/O error
*/
private void indexTrainingSet() throws Exception {
//class1
addDoc(adoc(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
addDoc(adoc(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
addDoc(adoc(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
addDoc(adoc(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
//class2
addDoc(adoc(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
addDoc(adoc(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
addDoc(adoc(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
addDoc(adoc(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
//class3
addDoc(adoc(ID, "9",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(adoc(ID, "10",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(adoc(ID, "11",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(adoc(ID, "12",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(commit());
}
private Document getDoc(String id) throws IOException {
try (SolrQueryRequest req = req()) {
SolrIndexSearcher searcher = req.getSearcher();
TermQuery query = new TermQuery(new Term(ID, id));
TopDocs doc1 = searcher.search(query, 1);
ScoreDoc scoreDoc = doc1.scoreDocs[0];
return searcher.doc(scoreDoc.doc);
}
}
private void addDoc(String doc) throws Exception {
addDoc(doc, CHAIN);
}
}

View File

@ -0,0 +1,507 @@
package org.apache.solr.update.processor;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.update.AddUpdateCommand;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Tests for {@link ClassificationUpdateProcessor}
*/
public class ClassificationUpdateProcessorTest extends SolrTestCaseJ4 {
/* field names are used in accordance with the solrconfig and schema supplied */
private static final String ID = "id";
private static final String TITLE = "title";
private static final String CONTENT = "content";
private static final String AUTHOR = "author";
private static final String TRAINING_CLASS = "cat";
private static final String PREDICTED_CLASS = "predicted";
public static final String KNN = "knn";
protected Directory directory;
protected IndexReader reader;
protected IndexSearcher searcher;
protected Analyzer analyzer = new MockAnalyzer(random(), MockTokenizer.WHITESPACE, false);
private ClassificationUpdateProcessor updateProcessorToTest;
@BeforeClass
public static void beforeClass() throws Exception {
System.setProperty("enable.update.log", "false");
initCore("solrconfig-classification.xml", "schema-classification.xml");
}
@Override
public void setUp() throws Exception {
super.setUp();
}
@Override
public void tearDown() throws Exception {
reader.close();
directory.close();
analyzer.close();
super.tearDown();
}
@Test
public void classificationMonoClass_predictedClassFieldSet_shouldAssignClassInPredictedClassField() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setPredictedClassField(PREDICTED_CLASS);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(PREDICTED_CLASS),is("class1"));
}
@Test
public void knnMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
}
@Test
public void knnMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setInputFieldNames(new String[]{TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"});
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
}
@Test
public void bayesMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
}
@Test
public void knnMonoClass_contextQueryFiltered_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "a");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
Query class3DocsChunk=new TermQuery(new Term(TITLE,"word6"));
params.setTrainingFilterQuery(class3DocsChunk);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class3"));
}
@Test
public void bayesMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
}
@Test
public void knnClassification_maxOutputClassesGreaterThanAvailable_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setMaxPredictedClasses(100);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.get(0),is("class2"));
assertThat(assignedClasses.get(1),is("class1"));
}
@Test
public void knnMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class2"));
assertThat(assignedClasses.get(1),is("class1"));
}
@Test
public void bayesMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class2"));
assertThat(assignedClasses.get(1),is("class1"));
}
@Test
public void knnMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class4"));
assertThat(assignedClasses.get(1),is("class6"));
}
@Test
public void bayesMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class4"));
assertThat(assignedClasses.get(1),is("class6"));
}
private ClassificationUpdateProcessorParams initParams(ClassificationUpdateProcessorFactory.Algorithm classificationAlgorithm) {
ClassificationUpdateProcessorParams params= new ClassificationUpdateProcessorParams();
params.setInputFieldNames(new String[]{TITLE,CONTENT,AUTHOR});
params.setTrainingClassField(TRAINING_CLASS);
params.setPredictedClassField(TRAINING_CLASS);
params.setMinTf(1);
params.setMinDf(1);
params.setK(5);
params.setAlgorithm(classificationAlgorithm);
params.setMaxPredictedClasses(1);
return params;
}
/**
* Index some example documents with a class manually assigned.
* This will be our trained model.
*
* @throws Exception If there is a low-level I/O error
*/
private void prepareTrainedIndexMonoClass() throws Exception {
directory = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
//class1
addDoc(writer, buildLuceneDocument(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
addDoc(writer, buildLuceneDocument(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
addDoc(writer, buildLuceneDocument(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
addDoc(writer, buildLuceneDocument(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
//class2
addDoc(writer, buildLuceneDocument(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
addDoc(writer, buildLuceneDocument(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
addDoc(writer, buildLuceneDocument(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
addDoc(writer, buildLuceneDocument(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
//class3
addDoc(writer, buildLuceneDocument(ID, "9",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
addDoc(writer, buildLuceneDocument(ID, "10",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
addDoc(writer, buildLuceneDocument(ID, "11",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
addDoc(writer, buildLuceneDocument(ID, "12",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
reader = writer.getReader();
writer.close();
searcher = newSearcher(reader);
}
private void prepareTrainedIndexMultiClass() throws Exception {
directory = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
//class1
addDoc(writer, buildLuceneDocument(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class1",
TRAINING_CLASS, "class2"
));
addDoc(writer, buildLuceneDocument(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class3",
TRAINING_CLASS, "class2"
));
addDoc(writer, buildLuceneDocument(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class1",
TRAINING_CLASS, "class2"
));
addDoc(writer, buildLuceneDocument(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class1",
TRAINING_CLASS, "class2"
));
//class2
addDoc(writer, buildLuceneDocument(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class6",
TRAINING_CLASS, "class4"
));
addDoc(writer, buildLuceneDocument(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class5",
TRAINING_CLASS, "class4"
));
addDoc(writer, buildLuceneDocument(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class6",
TRAINING_CLASS, "class4"
));
addDoc(writer, buildLuceneDocument(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class6",
TRAINING_CLASS, "class4"
));
reader = writer.getReader();
writer.close();
searcher = newSearcher(reader);
}
public static Document buildLuceneDocument(Object... fieldsAndValues) {
Document luceneDoc = new Document();
for (int i=0; i<fieldsAndValues.length; i+=2) {
luceneDoc.add(newTextField((String)fieldsAndValues[i], (String)fieldsAndValues[i+1], Field.Store.YES));
}
return luceneDoc;
}
private int addDoc(RandomIndexWriter writer, Document doc) throws IOException {
writer.addDocument(doc);
return writer.numDocs() - 1;
}
}

View File

@ -16,31 +16,28 @@
*/
package org.apache.solr.update.processor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.util.Constants;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.client.solrj.impl.BinaryRequestWriter;
import org.apache.solr.client.solrj.request.UpdateRequest;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
*
*/
@ -359,21 +356,4 @@ public class SignatureUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
private void addDoc(String doc) throws Exception {
addDoc(doc, chain);
}
static void addDoc(String doc, String chain) throws Exception {
Map<String, String[]> params = new HashMap<>();
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
(SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
req.close();
}
}

View File

@ -25,8 +25,6 @@ import org.junit.Test;
import java.util.Map;
import static org.apache.solr.update.processor.SignatureUpdateProcessorFactoryTest.addDoc;
public class TestPartialUpdateDeduplication extends SolrTestCaseJ4 {
@BeforeClass
public static void beforeClass() throws Exception {

View File

@ -83,7 +83,11 @@ import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.SolrInputField;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.ObjectReleaseTracker;
import org.apache.solr.common.util.XML;
import org.apache.solr.core.CoreContainer;
@ -96,7 +100,9 @@ import org.apache.solr.core.SolrXmlConfig;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.request.SolrRequestHandler;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.SolrIndexSearcher;
@ -1009,6 +1015,22 @@ public abstract class SolrTestCaseJ4 extends LuceneTestCase {
return out.toString();
}
public static void addDoc(String doc, String updateRequestProcessorChain) throws Exception {
Map<String, String[]> params = new HashMap<>();
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
params.put(UpdateParams.UPDATE_CHAIN, new String[]{updateRequestProcessorChain});
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
(SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
req.close();
}
/**
* Generates an &lt;add&gt;&lt;doc&gt;... XML String with options