Merge remote-tracking branch 'origin/master'

Noble Paul 2016-11-25 08:18:09 +05:30
commit 950ff50032
24 changed files with 1893 additions and 455 deletions

View File

@@ -195,9 +195,10 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
float maxScore = topDocs.getMaxScore();
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
IndexableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
if (storableField != null) {
BytesRef cl = new BytesRef(storableField.stringValue());
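// a document's class field may be multiValued; consider every stored value, not just the first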
IndexableField[] storableFields = indexSearcher.doc(scoreDoc.doc).getFields(classFieldName);
for (IndexableField singleStorableField : storableFields) {
if (singleStorableField != null) {
BytesRef cl = new BytesRef(singleStorableField.stringValue());
//update count
Integer count = classCounts.get(cl);
if (count != null) {
@@ -213,6 +214,7 @@ public class KNearestNeighborClassifier implements Classifier<BytesRef> {
} else {
classBoosts.put(cl, singleBoost);
}
}
}
}
List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();

View File

@@ -109,6 +109,7 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
TopDocs knnResults = knnSearch(document);
List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
Collections.sort(assignedClasses);
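// don't return more classes than the neighborhood actually produced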
max = Math.min(max, assignedClasses.size());
return assignedClasses.subList(0, max);
}
@@ -130,15 +131,14 @@ public class KNearestNeighborDocumentClassifier extends KNearestNeighborClassifi
boost = field2boost[1];
}
String[] fieldValues = document.getValues(fieldName);
mlt.setBoost(true); // we want always to use the boost coming from TF * IDF of the term
if (boost != null) {
mlt.setBoost(true);
mlt.setBoostFactor(Float.parseFloat(boost));
mlt.setBoostFactor(Float.parseFloat(boost)); // this is an additional multiplicative boost coming from the field boost
}
mlt.setAnalyzer(field2analyzer.get(fieldName));
for (String fieldContent : fieldValues) {
mltQuery.add(new BooleanClause(mlt.like(fieldName, new StringReader(fieldContent)), BooleanClause.Occur.SHOULD));
}
mlt.setBoost(false);
}
Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));

View File

@@ -98,6 +98,13 @@ Upgrade Notes
replaced by corresponding per-second rates viz. "avgRequestsPerSecond", "5minRateRequestsPerSecond"
and "15minRateRequestsPerSecond" for consistency with stats output in other parts of Solr.
* SOLR-9708: You are encouraged to try out the UnifiedHighlighter by setting hl.method=unified and report feedback. It
might become the default in 7.0. It's more efficient/faster than the other highlighters, especially compared to the
original Highlighter. That said, some options aren't supported yet, notably hl.fragsize and
hl.requireFieldMatch=false. It will get more features in time, especially with your input. See HighlightParams.java
for a listing of highlight parameters annotated with which highlighters use them.
hl.useFastVectorHighlighter is now considered deprecated in favor of hl.method=fastVector.
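  For reference, a minimal SolrJ sketch of trying the new method at request time (the core URL and
  field names below are illustrative assumptions, not part of this commit):

    import java.util.List;
    import java.util.Map;
    import org.apache.solr.client.solrj.SolrQuery;
    import org.apache.solr.client.solrj.impl.HttpSolrClient;
    import org.apache.solr.client.solrj.response.QueryResponse;

    public class UnifiedHighlighterTryout {
      public static void main(String[] args) throws Exception {
        // assumption: a local core named "techproducts" with a stored, analyzed "name" field
        try (HttpSolrClient client = new HttpSolrClient.Builder("http://localhost:8983/solr/techproducts").build()) {
          SolrQuery q = new SolrQuery("name:memory");
          q.setHighlight(true);
          q.set("hl.fl", "name");
          q.set("hl.method", "unified"); // request-time selection; no solrconfig change required
          QueryResponse rsp = client.query(q);
          Map<String, Map<String, List<String>>> hl = rsp.getHighlighting(); // uniqueKey -> field -> snippets
          hl.forEach((id, fields) -> System.out.println(id + " => " + fields));
        }
      }
    }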
New Features
----------------------
* SOLR-9293: Solrj client support for hierarchical clusters and other topics
@@ -137,6 +144,12 @@ New Features
* SOLR-9721: javabin Tuple parser for streaming and other end points (noble)
* SOLR-9708: Added UnifiedSolrHighlighter, a highlighter adapter for Lucene's UnifiedHighlighter. The adapter is a
derivative of the PostingsSolrHighlighter, supporting mostly the same parameters with some differences.
Introduced "hl.method" parameter which can be set to original|fastVector|postings|unified to pick the highlighter at
runtime without the need to modify solrconfig from the default configuration. hl.useFastVectorHighlighter is now
considered deprecated in favor of hl.method=fastVector. (Timothy Rodriguez, David Smiley)
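  A tiny sketch of the parameter migration using plain SolrJ params (only the param names above are
  from this commit; the surrounding class is illustrative):

    import org.apache.solr.common.params.ModifiableSolrParams;

    public class HlMethodMigration {
      public static void main(String[] args) {
        ModifiableSolrParams legacy = new ModifiableSolrParams();
        legacy.set("hl", "true");
        legacy.set("hl.useFastVectorHighlighter", "true"); // deprecated spelling

        ModifiableSolrParams current = new ModifiableSolrParams();
        current.set("hl", "true");
        current.set("hl.method", "fastVector"); // accepted values: original|fastVector|postings|unified

        System.out.println(legacy + " -> " + current);
      }
    }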
Optimizations
----------------------
* SOLR-9704: Facet Module / JSON Facet API: Optimize blockChildren facets that have

View File

@@ -16,22 +16,12 @@
*/
package org.apache.solr.uima.processor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.util.LuceneTestCase.Slow;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.uima.processor.SolrUIMAConfiguration.MapField;
import org.apache.solr.update.processor.UpdateRequestProcessor;
import org.apache.solr.update.processor.UpdateRequestProcessorChain;
@@ -188,19 +178,4 @@ public class UIMAUpdateRequestProcessorTest extends SolrTestCaseJ4 {
}
}
private void addDoc(String chain, String doc) throws Exception {
Map<String, String[]> params = new HashMap<>();
params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(), (SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
}
}

View File

@@ -16,6 +16,14 @@
*/
package org.apache.solr.handler.component;
import java.io.IOException;
import java.net.URL;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;
import com.google.common.base.Objects;
import org.apache.lucene.search.Query;
import org.apache.solr.common.SolrException;
@@ -29,6 +37,7 @@ import org.apache.solr.core.SolrCore;
import org.apache.solr.highlight.DefaultSolrHighlighter;
import org.apache.solr.highlight.PostingsSolrHighlighter;
import org.apache.solr.highlight.SolrHighlighter;
import org.apache.solr.highlight.UnifiedSolrHighlighter;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin;
@@ -38,9 +47,7 @@ import org.apache.solr.util.SolrPluginUtils;
import org.apache.solr.util.plugin.PluginInfoInitialized;
import org.apache.solr.util.plugin.SolrCoreAware;
import java.io.IOException;
import java.net.URL;
import java.util.List;
import static java.util.stream.Collectors.toMap;
/**
* TODO!
@@ -50,13 +57,50 @@ import java.util.List;
*/
public class HighlightComponent extends SearchComponent implements PluginInfoInitialized, SolrCoreAware
{
public static final String COMPONENT_NAME = "highlight";
private PluginInfo info = PluginInfo.EMPTY_INFO;
private SolrHighlighter highlighter;
public enum HighlightMethod {
UNIFIED("unified"),
FAST_VECTOR("fastVector"),
POSTINGS("postings"),
ORIGINAL("original");
private static final Map<String, HighlightMethod> METHODS = Collections.unmodifiableMap(Stream.of(values())
.collect(toMap(HighlightMethod::getMethodName, Function.identity())));
private final String methodName;
HighlightMethod(String method) {
this.methodName = method;
}
public String getMethodName() {
return methodName;
}
public static HighlightMethod parse(String method) {
return METHODS.get(method);
}
}
public static final String COMPONENT_NAME = "highlight";
private PluginInfo info = PluginInfo.EMPTY_INFO;
@Deprecated // DWS: in 7.0 lets restructure the abstractions/relationships
private SolrHighlighter solrConfigHighlighter;
/**
* @deprecated instead depend on {@link #process(ResponseBuilder)} to choose the highlighter based on
* {@link HighlightParams#METHOD}
*/
@Deprecated
public static SolrHighlighter getHighlighter(SolrCore core) {
HighlightComponent hl = (HighlightComponent) core.getSearchComponents().get(HighlightComponent.COMPONENT_NAME);
return hl==null ? null: hl.getHighlighter();
}
@Deprecated
public SolrHighlighter getHighlighter() {
return solrConfigHighlighter;
}
@Override
@@ -67,7 +111,7 @@ public class HighlightComponent extends SearchComponent implements PluginInfoIni
@Override
public void prepare(ResponseBuilder rb) throws IOException {
SolrParams params = rb.req.getParams();
rb.doHighlights = highlighter.isHighlightingEnabled(params);
rb.doHighlights = solrConfigHighlighter.isHighlightingEnabled(params);
if(rb.doHighlights){
rb.setNeedDocList(true);
String hlq = params.get(HighlightParams.Q);
@@ -90,26 +134,28 @@ public class HighlightComponent extends SearchComponent implements PluginInfoIni
if(children.isEmpty()) {
PluginInfo pluginInfo = core.getSolrConfig().getPluginInfo(SolrHighlighter.class.getName()); //TODO deprecated configuration remove later
if (pluginInfo != null) {
highlighter = core.createInitInstance(pluginInfo, SolrHighlighter.class, null, DefaultSolrHighlighter.class.getName());
solrConfigHighlighter = core.createInitInstance(pluginInfo, SolrHighlighter.class, null, DefaultSolrHighlighter.class.getName());
} else {
DefaultSolrHighlighter defHighlighter = new DefaultSolrHighlighter(core);
defHighlighter.init(PluginInfo.EMPTY_INFO);
highlighter = defHighlighter;
solrConfigHighlighter = defHighlighter;
}
} else {
highlighter = core.createInitInstance(children.get(0),SolrHighlighter.class,null, DefaultSolrHighlighter.class.getName());
solrConfigHighlighter = core.createInitInstance(children.get(0),SolrHighlighter.class,null, DefaultSolrHighlighter.class.getName());
}
}
@Override
public void process(ResponseBuilder rb) throws IOException {
if (rb.doHighlights) {
SolrQueryRequest req = rb.req;
SolrParams params = req.getParams();
String[] defaultHighlightFields; //TODO: get from builder by default?
SolrHighlighter highlighter = getHighlighter(params);
String[] defaultHighlightFields; //TODO: get from builder by default?
if (rb.getQparser() != null) {
defaultHighlightFields = rb.getQparser().getDefaultHighlightFields();
} else {
@@ -130,14 +176,8 @@ public class HighlightComponent extends SearchComponent implements PluginInfoIni
rb.setHighlightQuery( highlightQuery );
}
}
if(highlightQuery != null) {
boolean rewrite = (highlighter instanceof PostingsSolrHighlighter == false) && !(Boolean.valueOf(params.get(HighlightParams.USE_PHRASE_HIGHLIGHTER, "true")) &&
Boolean.valueOf(params.get(HighlightParams.HIGHLIGHT_MULTI_TERM, "true")));
highlightQuery = rewrite ? highlightQuery.rewrite(req.getSearcher().getIndexReader()) : highlightQuery;
}
// No highlighting if there is no query -- consider q.alt="*:*
// No highlighting if there is no query -- consider q.alt=*:*
if( highlightQuery != null ) {
NamedList sumData = highlighter.doHighlighting(
rb.getResults().docList,
@@ -152,6 +192,36 @@ public class HighlightComponent extends SearchComponent implements PluginInfoIni
}
}
protected SolrHighlighter getHighlighter(SolrParams params) {
HighlightMethod method = HighlightMethod.parse(params.get(HighlightParams.METHOD));
if (method == null) {
return solrConfigHighlighter;
}
switch (method) {
case UNIFIED:
if (solrConfigHighlighter instanceof UnifiedSolrHighlighter) {
return solrConfigHighlighter;
}
return new UnifiedSolrHighlighter(); // TODO cache one?
case POSTINGS:
if (solrConfigHighlighter instanceof PostingsSolrHighlighter) {
return solrConfigHighlighter;
}
return new PostingsSolrHighlighter(); // TODO cache one?
case FAST_VECTOR: // fall-through
case ORIGINAL:
if (solrConfigHighlighter instanceof DefaultSolrHighlighter) {
return solrConfigHighlighter;
} else {
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
"In order to use " + HighlightParams.METHOD + "=" + method.getMethodName() + " the configured" +
" highlighter in solrconfig must be " + DefaultSolrHighlighter.class);
}
default: throw new AssertionError();
}
}
@Override
public void modifyRequest(ResponseBuilder rb, SearchComponent who, ShardRequest sreq) {
if (!rb.doHighlights) return;
@@ -195,10 +265,6 @@ public class HighlightComponent extends SearchComponent implements PluginInfoIni
}
}
public SolrHighlighter getHighlighter() {
return highlighter;
}
////////////////////////////////////////////
/// SolrInfoMBean
////////////////////////////////////////////

View File

@@ -66,6 +66,7 @@ import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SimpleOrderedMap;
import org.apache.solr.core.PluginInfo;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.HighlightComponent;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
@@ -373,6 +374,13 @@ public class DefaultSolrHighlighter extends SolrHighlighter implements PluginInf
if (!isHighlightingEnabled(params)) // also returns early if no unique key field
return null;
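// rewrite the query up-front unless both the phrase highlighter and multi-term highlighting are enabled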
boolean rewrite = query != null && !(Boolean.valueOf(params.get(HighlightParams.USE_PHRASE_HIGHLIGHTER, "true")) &&
Boolean.valueOf(params.get(HighlightParams.HIGHLIGHT_MULTI_TERM, "true")));
if (rewrite) {
query = query.rewrite(req.getSearcher().getIndexReader());
}
SolrIndexSearcher searcher = req.getSearcher();
IndexSchema schema = searcher.getSchema();
@@ -463,8 +471,11 @@ public class DefaultSolrHighlighter extends SolrHighlighter implements PluginInf
* Determines if we should use the FastVectorHighlighter for this field.
*/
protected boolean useFastVectorHighlighter(SolrParams params, SchemaField schemaField) {
boolean useFvhParam = params.getFieldBool(schemaField.getName(), HighlightParams.USE_FVH, false);
if (!useFvhParam) return false;
boolean methodFvh =
HighlightComponent.HighlightMethod.FAST_VECTOR.getMethodName().equals(
params.getFieldParam(schemaField.getName(), HighlightParams.METHOD))
|| params.getFieldBool(schemaField.getName(), HighlightParams.USE_FVH, false);
if (!methodFvh) return false;
boolean termPosOff = schemaField.storeTermPositions() && schemaField.storeTermOffsets();
if (!termPosOff) {
log.warn("Solr will use the standard Highlighter instead of FastVectorHighlighter because the {} field " +

View File

@@ -50,8 +50,9 @@ import org.apache.solr.util.plugin.PluginInfoInitialized;
* <p>
* Example configuration:
* <pre class="prettyprint">
* &lt;requestHandler name="standard" class="solr.StandardRequestHandler"&gt;
* &lt;requestHandler name="/select" class="solr.SearchHandler"&gt;
* &lt;lst name="defaults"&gt;
* &lt;str name="hl.method"&gt;postings&lt;/str&gt;
* &lt;int name="hl.snippets"&gt;1&lt;/int&gt;
* &lt;str name="hl.tag.pre"&gt;&amp;lt;em&amp;gt;&lt;/str&gt;
* &lt;str name="hl.tag.post"&gt;&amp;lt;/em&amp;gt;&lt;/str&gt;
@@ -71,12 +72,6 @@ import org.apache.solr.util.plugin.PluginInfoInitialized;
* &lt;/lst&gt;
* &lt;/requestHandler&gt;
* </pre>
* ...
* <pre class="prettyprint">
* &lt;searchComponent class="solr.HighlightComponent" name="highlight"&gt;
* &lt;highlighting class="org.apache.solr.highlight.PostingsSolrHighlighter"/&gt;
* &lt;/searchComponent&gt;
* </pre>
* <p>
* Notes:
* <ul>

View File

@@ -0,0 +1,365 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.highlight;
import java.io.IOException;
import java.text.BreakIterator;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.postingshighlight.WholeBreakIterator;
import org.apache.lucene.search.uhighlight.DefaultPassageFormatter;
import org.apache.lucene.search.uhighlight.PassageFormatter;
import org.apache.lucene.search.uhighlight.PassageScorer;
import org.apache.lucene.search.uhighlight.UnifiedHighlighter;
import org.apache.solr.common.params.HighlightParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SimpleOrderedMap;
import org.apache.solr.core.PluginInfo;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrRequestInfo;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.DocIterator;
import org.apache.solr.search.DocList;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.RTimerTree;
import org.apache.solr.util.plugin.PluginInfoInitialized;
/**
* Highlighter impl that uses {@link UnifiedHighlighter}
* <p>
* Example configuration with default values:
* <pre class="prettyprint">
* &lt;requestHandler name="/select" class="solr.SearchHandler"&gt;
* &lt;lst name="defaults"&gt;
* &lt;str name="hl.method"&gt;unified&lt;/str&gt;
* &lt;int name="hl.snippets"&gt;1&lt;/int&gt;
* &lt;str name="hl.tag.pre"&gt;&amp;lt;em&amp;gt;&lt;/str&gt;
* &lt;str name="hl.tag.post"&gt;&amp;lt;/em&amp;gt;&lt;/str&gt;
* &lt;str name="hl.simple.pre"&gt;&amp;lt;em&amp;gt;&lt;/str&gt;
* &lt;str name="hl.simple.post"&gt;&amp;lt;/em&amp;gt;&lt;/str&gt;
* &lt;str name="hl.tag.ellipsis"&gt;... &lt;/str&gt;
* &lt;bool name="hl.defaultSummary"&gt;true&lt;/bool&gt;
* &lt;str name="hl.encoder"&gt;simple&lt;/str&gt;
* &lt;float name="hl.score.k1"&gt;1.2&lt;/float&gt;
* &lt;float name="hl.score.b"&gt;0.75&lt;/float&gt;
* &lt;float name="hl.score.pivot"&gt;87&lt;/float&gt;
* &lt;str name="hl.bs.language"&gt;&lt;/str&gt;
* &lt;str name="hl.bs.country"&gt;&lt;/str&gt;
* &lt;str name="hl.bs.variant"&gt;&lt;/str&gt;
* &lt;str name="hl.bs.type"&gt;SENTENCE&lt;/str&gt;
* &lt;int name="hl.maxAnalyzedChars"&gt;10000&lt;/int&gt;
* &lt;bool name="hl.highlightMultiTerm"&gt;true&lt;/bool&gt;
* &lt;bool name="hl.usePhraseHighlighter"&gt;true&lt;/bool&gt;
* &lt;int name="hl.cacheFieldValCharsThreshold"&gt;524288&lt;/int&gt;
* &lt;str name="hl.offsetSource"&gt;&lt;/str&gt;
* &lt;/lst&gt;
* &lt;/requestHandler&gt;
* </pre>
* <p>
* Notes:
* <ul>
* <li>hl.q (string) can specify the query
* <li>hl.fl (string) specifies the field list.
* <li>hl.snippets (int) specifies how many snippets to return.
* <li>hl.tag.pre (string) specifies text which appears before a highlighted term.
* <li>hl.tag.post (string) specifies text which appears after a highlighted term.
* <li>hl.simple.pre (string) specifies text which appears before a highlighted term. (prefer hl.tag.pre)
* <li>hl.simple.post (string) specifies text which appears after a highlighted term. (prefer hl.tag.post)
* <li>hl.tag.ellipsis (string) specifies text which joins non-adjacent passages. The default is to retain each
* value in a list without joining them.
* <li>hl.defaultSummary (bool) specifies if a field should have a default summary of the leading text.
* <li>hl.encoder (string) can be 'html' (html escapes content) or 'simple' (no escaping).
* <li>hl.score.k1 (float) specifies bm25 scoring parameter 'k1'
* <li>hl.score.b (float) specifies bm25 scoring parameter 'b'
* <li>hl.score.pivot (float) specifies bm25 scoring parameter 'avgdl'
* <li>hl.bs.type (string) specifies how to divide text into passages: [SENTENCE, LINE, WORD, CHAR, WHOLE]
* <li>hl.bs.language (string) specifies language code for BreakIterator. default is empty string (root locale)
* <li>hl.bs.country (string) specifies country code for BreakIterator. default is empty string (root locale)
* <li>hl.bs.variant (string) specifies variant code for BreakIterator. default is empty string (root locale)
* <li>hl.maxAnalyzedChars (int) specifies how many characters at most will be processed in a document for any one field.
* <li>hl.highlightMultiTerm (bool) enables highlighting for range/wildcard/fuzzy/prefix queries at some cost. default is true
* <li>hl.usePhraseHighlighter (bool) enables phrase highlighting. default is true
* <li>hl.cacheFieldValCharsThreshold (int) controls how many characters from a field are cached. default is 524288 (1MB in 2 byte chars)
* <li>hl.offsetSource (string) specifies which offset source to use, prefers postings, but will use what's available if not specified
* </ul>
*
* @lucene.experimental
*/
public class UnifiedSolrHighlighter extends SolrHighlighter implements PluginInfoInitialized {
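// NUL separator, used as the default hl.tag.ellipsis, so encodeSnippets can split joined passages back into a multi-valued field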
protected static final String SNIPPET_SEPARATOR = "\u0000";
private static final String[] ZERO_LEN_STR_ARRAY = new String[0];
@Override
public void init(PluginInfo info) {
}
@Override
public NamedList<Object> doHighlighting(DocList docs, Query query, SolrQueryRequest req, String[] defaultFields) throws IOException {
final SolrParams params = req.getParams();
// if highlighting isn't enabled, then why call doHighlighting?
if (!isHighlightingEnabled(params))
return null;
int[] docIDs = toDocIDs(docs);
// fetch the unique keys
String[] keys = getUniqueKeys(req.getSearcher(), docIDs);
// query-time parameters
String[] fieldNames = getHighlightFields(query, req, defaultFields);
int maxPassages[] = new int[fieldNames.length];
for (int i = 0; i < fieldNames.length; i++) {
maxPassages[i] = params.getFieldInt(fieldNames[i], HighlightParams.SNIPPETS, 1);
}
UnifiedHighlighter highlighter = getHighlighter(req);
Map<String, String[]> snippets = highlighter.highlightFields(fieldNames, query, docIDs, maxPassages);
return encodeSnippets(keys, fieldNames, snippets);
}
/**
* Creates an instance of the Lucene {@link UnifiedHighlighter}. Provided for subclass extension so that
* a subclass can return a subclass of {@link SolrExtendedUnifiedHighlighter}.
*/
protected UnifiedHighlighter getHighlighter(SolrQueryRequest req) {
return new SolrExtendedUnifiedHighlighter(req);
}
/**
* Encodes the resulting snippets into a namedlist
*
* @param keys the document unique keys
* @param fieldNames field names to highlight in the order
* @param snippets map from field name to snippet array for the docs
* @return encoded namedlist of summaries
*/
protected NamedList<Object> encodeSnippets(String[] keys, String[] fieldNames, Map<String, String[]> snippets) {
NamedList<Object> list = new SimpleOrderedMap<>();
for (int i = 0; i < keys.length; i++) {
NamedList<Object> summary = new SimpleOrderedMap<>();
for (String field : fieldNames) {
String snippet = snippets.get(field)[i];
if (snippet == null) {
//TODO reuse logic of DefaultSolrHighlighter.alternateField
summary.add(field, ZERO_LEN_STR_ARRAY);
} else {
// we used a special snippet separator char and we can now split on it.
summary.add(field, snippet.split(SNIPPET_SEPARATOR));
}
}
list.add(keys[i], summary);
}
return list;
}
/**
* Converts solr's DocList to the int[] docIDs
*/
protected int[] toDocIDs(DocList docs) {
int[] docIDs = new int[docs.size()];
DocIterator iterator = docs.iterator();
for (int i = 0; i < docIDs.length; i++) {
if (!iterator.hasNext()) {
throw new AssertionError();
}
docIDs[i] = iterator.nextDoc();
}
if (iterator.hasNext()) {
throw new AssertionError();
}
return docIDs;
}
/**
* Retrieves the unique keys for the topdocs to key the results
*/
protected String[] getUniqueKeys(SolrIndexSearcher searcher, int[] docIDs) throws IOException {
IndexSchema schema = searcher.getSchema();
SchemaField keyField = schema.getUniqueKeyField();
if (keyField != null) {
Set<String> selector = Collections.singleton(keyField.getName());
String[] uniqueKeys = new String[docIDs.length];
for (int i = 0; i < docIDs.length; i++) {
int docid = docIDs[i];
Document doc = searcher.doc(docid, selector);
String id = schema.printableUniqueKey(doc);
uniqueKeys[i] = id;
}
return uniqueKeys;
} else {
return new String[docIDs.length];
}
}
/**
* From {@link #getHighlighter(org.apache.solr.request.SolrQueryRequest)}.
*/
protected static class SolrExtendedUnifiedHighlighter extends UnifiedHighlighter {
protected final SolrParams params;
protected final IndexSchema schema;
protected final RTimerTree loadFieldValuesTimer;
public SolrExtendedUnifiedHighlighter(SolrQueryRequest req) {
super(req.getSearcher(), req.getSchema().getIndexAnalyzer());
this.params = req.getParams();
this.schema = req.getSchema();
this.setMaxLength(
params.getInt(HighlightParams.MAX_CHARS, UnifiedHighlighter.DEFAULT_MAX_LENGTH));
this.setCacheFieldValCharsThreshold(
params.getInt(HighlightParams.CACHE_FIELD_VAL_CHARS_THRESHOLD, DEFAULT_CACHE_CHARS_THRESHOLD));
// SolrRequestInfo is a thread-local singleton providing access to the ResponseBuilder to code that
// otherwise can't get it in a nicer way.
SolrQueryRequest request = SolrRequestInfo.getRequestInfo().getReq();
final RTimerTree timerTree;
if (request.getRequestTimer() != null) { //It may be null if not used in a search context.
timerTree = request.getRequestTimer();
} else {
timerTree = new RTimerTree(); // since null checks are annoying
}
loadFieldValuesTimer = timerTree.sub("loadFieldValues"); // we assume a new timer, state of STARTED
loadFieldValuesTimer.pause(); // state of PAUSED now with about zero time. Will fail if state isn't STARTED.
}
@Override
protected OffsetSource getOffsetSource(String field) {
String sourceStr = params.getFieldParam(field, HighlightParams.OFFSET_SOURCE);
if (sourceStr != null) {
return OffsetSource.valueOf(sourceStr.toUpperCase(Locale.ROOT));
} else {
return super.getOffsetSource(field);
}
}
@Override
public int getMaxNoHighlightPassages(String field) {
boolean defaultSummary = params.getFieldBool(field, HighlightParams.DEFAULT_SUMMARY, false);
if (defaultSummary) {
return -1;// signifies return first hl.snippets passages worth of the content
} else {
return 0;// will return null
}
}
@Override
protected PassageFormatter getFormatter(String fieldName) {
String preTag = params.getFieldParam(fieldName, HighlightParams.TAG_PRE,
params.getFieldParam(fieldName, HighlightParams.SIMPLE_PRE, "<em>")
);
String postTag = params.getFieldParam(fieldName, HighlightParams.TAG_POST,
params.getFieldParam(fieldName, HighlightParams.SIMPLE_POST, "</em>")
);
String ellipsis = params.getFieldParam(fieldName, HighlightParams.TAG_ELLIPSIS, SNIPPET_SEPARATOR);
String encoder = params.getFieldParam(fieldName, HighlightParams.ENCODER, "simple");
return new DefaultPassageFormatter(preTag, postTag, ellipsis, "html".equals(encoder));
}
@Override
protected PassageScorer getScorer(String fieldName) {
float k1 = params.getFieldFloat(fieldName, HighlightParams.SCORE_K1, 1.2f);
float b = params.getFieldFloat(fieldName, HighlightParams.SCORE_B, 0.75f);
float pivot = params.getFieldFloat(fieldName, HighlightParams.SCORE_PIVOT, 87f);
return new PassageScorer(k1, b, pivot);
}
@Override
protected BreakIterator getBreakIterator(String field) {
String language = params.getFieldParam(field, HighlightParams.BS_LANGUAGE);
String country = params.getFieldParam(field, HighlightParams.BS_COUNTRY);
String variant = params.getFieldParam(field, HighlightParams.BS_VARIANT);
Locale locale = parseLocale(language, country, variant);
String type = params.getFieldParam(field, HighlightParams.BS_TYPE);
return parseBreakIterator(type, locale);
}
/**
* parse a break iterator type for the specified locale
*/
protected BreakIterator parseBreakIterator(String type, Locale locale) {
if (type == null || "SENTENCE".equals(type)) {
return BreakIterator.getSentenceInstance(locale);
} else if ("LINE".equals(type)) {
return BreakIterator.getLineInstance(locale);
} else if ("WORD".equals(type)) {
return BreakIterator.getWordInstance(locale);
} else if ("CHARACTER".equals(type)) {
return BreakIterator.getCharacterInstance(locale);
} else if ("WHOLE".equals(type)) {
return new WholeBreakIterator();
} else {
throw new IllegalArgumentException("Unknown " + HighlightParams.BS_TYPE + ": " + type);
}
}
/**
* parse a locale from a language+country+variant spec
*/
protected Locale parseLocale(String language, String country, String variant) {
if (language == null && country == null && variant == null) {
return Locale.ROOT;
} else if (language == null) {
throw new IllegalArgumentException("language is required if country or variant is specified");
} else if (country == null && variant != null) {
throw new IllegalArgumentException("To specify variant, country is required");
} else if (country != null && variant != null) {
return new Locale(language, country, variant);
} else if (country != null) {
return new Locale(language, country);
} else {
return new Locale(language);
}
}
@Override
protected List<CharSequence[]> loadFieldValues(String[] fields, DocIdSetIterator docIter, int
cacheCharsThreshold) throws IOException {
// Time loading field values. It can be an expensive part of highlighting.
loadFieldValuesTimer.resume();
try {
return super.loadFieldValues(fields, docIter, cacheCharsThreshold);
} finally {
loadFieldValuesTimer.pause(); // note: doesn't need to be "stopped"; pause is fine.
}
}
@Override
protected boolean shouldHandleMultiTermQuery(String field) {
return params.getFieldBool(field, HighlightParams.HIGHLIGHT_MULTI_TERM, true);
}
@Override
protected boolean shouldHighlightPhrasesStrictly(String field) {
return params.getFieldBool(field, HighlightParams.USE_PHRASE_HIGHLIGHTER, true);
}
}
}

View File

@@ -19,6 +19,7 @@ package org.apache.solr.update.processor;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
@@ -33,6 +34,7 @@ import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.update.AddUpdateCommand;
import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm;
/**
* This Class is a Request Update Processor to classify the document in input and add a field
@@ -42,43 +44,54 @@ import org.apache.solr.update.AddUpdateCommand;
class ClassificationUpdateProcessor
extends UpdateRequestProcessor {
private String classFieldName; // the field to index the assigned class
private final String trainingClassField;
private final String predictedClassField;
private final int maxOutputClasses;
private DocumentClassifier<BytesRef> classifier;
/**
* Sole constructor
*
* @param inputFieldNames fields to be used as classifier's inputs
* @param classFieldName field to be used as classifier's output
* @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
* @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
* @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
* @param algorithm the name of the classifier to use
* @param classificationParams classification advanced params
* @param next next update processor in the chain
* @param indexReader index reader
* @param schema schema
*/
public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
public ClassificationUpdateProcessor(ClassificationUpdateProcessorParams classificationParams, UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
super(next);
this.classFieldName = classFieldName;
Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
this.trainingClassField = classificationParams.getTrainingClassField();
this.predictedClassField = classificationParams.getPredictedClassField();
this.maxOutputClasses = classificationParams.getMaxPredictedClasses();
String[] inputFieldNamesWithBoost = classificationParams.getInputFieldNames();
Algorithm classificationAlgorithm = classificationParams.getAlgorithm();
Map<String, Analyzer> field2analyzer = new HashMap<>();
String[] inputFieldNames = this.removeBoost(inputFieldNamesWithBoost);
for (String fieldName : inputFieldNames) {
SchemaField fieldFromSolrSchema = schema.getField(fieldName);
Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
field2analyzer.put(fieldName, indexAnalyzer);
}
switch (algorithm) {
case "knn":
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
switch (classificationAlgorithm) {
case KNN:
classifier = new KNearestNeighborDocumentClassifier(indexReader, null, classificationParams.getTrainingFilterQuery(), classificationParams.getK(), classificationParams.getMinDf(), classificationParams.getMinTf(), trainingClassField, field2analyzer, inputFieldNamesWithBoost);
break;
case "bayes":
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
case BAYES:
classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, trainingClassField, field2analyzer, inputFieldNamesWithBoost);
break;
}
}
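/**
 * Strips the optional "^boost" suffix from each input field name (e.g. "title^2" becomes "title"),
 * since the analyzer lookup needs bare field names while the classifiers receive the boosted form.
 */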
private String[] removeBoost(String[] inputFieldNamesWithBoost) {
String[] inputFieldNames = new String[inputFieldNamesWithBoost.length];
for (int i = 0; i < inputFieldNamesWithBoost.length; i++) {
String singleFieldNameWithBoost = inputFieldNamesWithBoost[i];
String[] fieldName2boost = singleFieldNameWithBoost.split("\\^");
inputFieldNames[i] = fieldName2boost[0];
}
return inputFieldNames;
}
/**
* @param cmd the update command in input containing the Document to classify
* @throws IOException If there is a low-level I/O error
@@ -89,12 +102,14 @@ class ClassificationUpdateProcessor
SolrInputDocument doc = cmd.getSolrInputDocument();
Document luceneDocument = cmd.getLuceneDocument();
String assignedClass;
Object documentClass = doc.getFieldValue(classFieldName);
Object documentClass = doc.getFieldValue(trainingClassField);
if (documentClass == null) {
ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
if (classificationResult != null) {
assignedClass = classificationResult.getAssignedClass().utf8ToString();
doc.addField(classFieldName, assignedClass);
List<ClassificationResult<BytesRef>> assignedClassifications = classifier.getClasses(luceneDocument, maxOutputClasses);
if (assignedClassifications != null) {
for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) {
assignedClass = singleClassification.getAssignedClass().utf8ToString();
doc.addField(predictedClassField, assignedClass);
}
}
}
super.processAdd(cmd);

View File

@@ -18,12 +18,18 @@
package org.apache.solr.update.processor;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SuppressForbidden;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.search.LuceneQParser;
import org.apache.solr.search.SyntaxError;
import static org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm.KNN;
/**
* This class implements an UpdateProcessorFactory for the Classification Update Processor.
@@ -33,49 +39,68 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
// Update Processor Config params
private static final String INPUT_FIELDS_PARAM = "inputFields";
private static final String CLASS_FIELD_PARAM = "classField";
private static final String TRAINING_CLASS_FIELD_PARAM = "classField";
private static final String PREDICTED_CLASS_FIELD_PARAM = "predictedClassField";
private static final String MAX_CLASSES_TO_ASSIGN_PARAM = "predictedClass.maxCount";
private static final String ALGORITHM_PARAM = "algorithm";
private static final String KNN_MIN_TF_PARAM = "knn.minTf";
private static final String KNN_MIN_DF_PARAM = "knn.minDf";
private static final String KNN_K_PARAM = "knn.k";
private static final String KNN_FILTER_QUERY = "knn.filterQuery";
public enum Algorithm {KNN, BAYES}
//Update Processor Defaults
private static final int DEFAULT_MAX_CLASSES_TO_ASSIGN = 1;
private static final int DEFAULT_MIN_TF = 1;
private static final int DEFAULT_MIN_DF = 1;
private static final int DEFAULT_K = 10;
private static final String DEFAULT_ALGORITHM = "knn";
private static final Algorithm DEFAULT_ALGORITHM = KNN;
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
private String classFieldName; // the field containing the class for the Document
private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
private int minTf; // knn specific - the minimum Term Frequency for considering a term
private int minDf; // knn specific - the minimum Document Frequency for considering a term
private int k; // knn specific - the window of top results to evaluate, when assigning the class
private SolrParams params;
private ClassificationUpdateProcessorParams classificationParams;
@SuppressForbidden(reason = "Need toUpperCase to match algorithm enum value")
@Override
public void init(final NamedList args) {
if (args != null) {
SolrParams params = SolrParams.toSolrParams(args);
params = SolrParams.toSolrParams(args);
classificationParams = new ClassificationUpdateProcessorParams();
String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
inputFieldNames = fieldNames.split("\\,");
classificationParams.setInputFieldNames(fieldNames.split("\\,"));
classFieldName = params.get(CLASS_FIELD_PARAM);
checkNotNull(CLASS_FIELD_PARAM, classFieldName);
String trainingClassField = (params.get(TRAINING_CLASS_FIELD_PARAM));
checkNotNull(TRAINING_CLASS_FIELD_PARAM, trainingClassField);
classificationParams.setTrainingClassField(trainingClassField);
algorithm = params.get(ALGORITHM_PARAM);
if (algorithm == null)
algorithm = DEFAULT_ALGORITHM;
String predictedClassField = (params.get(PREDICTED_CLASS_FIELD_PARAM));
if (predictedClassField == null || predictedClassField.isEmpty()) {
predictedClassField = trainingClassField;
}
classificationParams.setPredictedClassField(predictedClassField);
minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
classificationParams.setMaxPredictedClasses(getIntParam(params, MAX_CLASSES_TO_ASSIGN_PARAM, DEFAULT_MAX_CLASSES_TO_ASSIGN));
String algorithmString = params.get(ALGORITHM_PARAM);
Algorithm classificationAlgorithm;
try {
if (algorithmString == null) { // Algorithm.valueOf never returns null; unknown names throw IllegalArgumentException below
classificationAlgorithm = DEFAULT_ALGORITHM;
} else {
classificationAlgorithm = Algorithm.valueOf(algorithmString.toUpperCase());
}
} catch (IllegalArgumentException e) {
throw new SolrException
(SolrException.ErrorCode.SERVER_ERROR,
"Classification UpdateProcessor Algorithm: '" + algorithmString + "' not supported");
}
classificationParams.setAlgorithm(classificationAlgorithm);
classificationParams.setMinTf(getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF));
classificationParams.setMinDf(getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF));
classificationParams.setK(getIntParam(params, KNN_K_PARAM, DEFAULT_K));
}
}
@@ -108,116 +133,34 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
@Override
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
String trainingFilterQueryString = (params.get(KNN_FILTER_QUERY));
try {
if (trainingFilterQueryString != null && !trainingFilterQueryString.isEmpty()) {
Query trainingFilterQuery = this.parseFilterQuery(trainingFilterQueryString, params, req);
classificationParams.setTrainingFilterQuery(trainingFilterQuery);
}
} catch (SyntaxError | RuntimeException syntaxError) {
throw new SolrException
(SolrException.ErrorCode.SERVER_ERROR,
"Classification UpdateProcessor Training Filter Query: '" + trainingFilterQueryString + "' is not supported", syntaxError);
}
IndexSchema schema = req.getSchema();
IndexReader indexReader = req.getSearcher().getIndexReader();
return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
return new ClassificationUpdateProcessor(classificationParams, next, indexReader, schema);
}
/**
* get field names used as classifier's inputs
*
* @return the input field names
*/
public String[] getInputFieldNames() {
return inputFieldNames;
private Query parseFilterQuery(String trainingFilterQueryString, SolrParams params, SolrQueryRequest req) throws SyntaxError {
LuceneQParser parser = new LuceneQParser(trainingFilterQueryString, null, params, req);
return parser.parse();
}
/**
* set field names used as classifier's inputs
*
* @param inputFieldNames the input field names
*/
public void setInputFieldNames(String[] inputFieldNames) {
this.inputFieldNames = inputFieldNames;
public ClassificationUpdateProcessorParams getClassificationParams() {
return classificationParams;
}
/**
* get field names used as classifier's output
*
* @return the output field name
*/
public String getClassFieldName() {
return classFieldName;
}
/**
* set field names used as classifier's output
*
* @param classFieldName the output field name
*/
public void setClassFieldName(String classFieldName) {
this.classFieldName = classFieldName;
}
/**
* get the name of the classifier algorithm used
*
* @return the classifier algorithm used
*/
public String getAlgorithm() {
return algorithm;
}
/**
* set the name of the classifier algorithm used
*
* @param algorithm the classifier algorithm used
*/
public void setAlgorithm(String algorithm) {
this.algorithm = algorithm;
}
/**
* get the min term frequency value to be used in case algorithm is {@code "knn"}
*
* @return the min term frequency
*/
public int getMinTf() {
return minTf;
}
/**
* set the min term frequency value to be used in case algorithm is {@code "knn"}
*
* @param minTf the min term frequency
*/
public void setMinTf(int minTf) {
this.minTf = minTf;
}
/**
* get the min document frequency value to be used in case algorithm is {@code "knn"}
*
* @return the min document frequency
*/
public int getMinDf() {
return minDf;
}
/**
* set the min document frequency value to be used in case algorithm is {@code "knn"}
*
* @param minDf the min document frequency
*/
public void setMinDf(int minDf) {
this.minDf = minDf;
}
/**
* get the no. of nearest neighbors to analyze, to be used in case algorithm is {@code "knn"}
*
* @return the no. of neighbors to analyze
*/
public int getK() {
return k;
}
/**
* set the no. of nearest neighbors to analyze, to be used in case algorithm is {@code "knn"}
*
* @param k the no. of neighbors to analyze
*/
public void setK(int k) {
this.k = k;
public void setClassificationParams(ClassificationUpdateProcessorParams classificationParams) {
this.classificationParams = classificationParams;
}
}

View File

@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.update.processor;
import org.apache.lucene.search.Query;
public class ClassificationUpdateProcessorParams {
private String[] inputFieldNames; // the array of fields to be sent to the Classifier
private Query trainingFilterQuery; // a filter query to reduce the training set to a subset
private String trainingClassField; // the field containing the class for the Document
private String predictedClassField; // the field that will contain the predicted class
private int maxPredictedClasses; // the max number of classes to assign
private ClassificationUpdateProcessorFactory.Algorithm algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
private int minTf; // knn specific - the minimum Term Frequency for considering a term
private int minDf; // knn specific - the minimum Document Frequency for considering a term
private int k; // knn specific - the window of top results to evaluate, when assigning the class
public String[] getInputFieldNames() {
return inputFieldNames;
}
public void setInputFieldNames(String[] inputFieldNames) {
this.inputFieldNames = inputFieldNames;
}
public Query getTrainingFilterQuery() {
return trainingFilterQuery;
}
public void setTrainingFilterQuery(Query trainingFilterQuery) {
this.trainingFilterQuery = trainingFilterQuery;
}
public String getTrainingClassField() {
return trainingClassField;
}
public void setTrainingClassField(String trainingClassField) {
this.trainingClassField = trainingClassField;
}
public String getPredictedClassField() {
return predictedClassField;
}
public void setPredictedClassField(String predictedClassField) {
this.predictedClassField = predictedClassField;
}
public int getMaxPredictedClasses() {
return maxPredictedClasses;
}
public void setMaxPredictedClasses(int maxPredictedClasses) {
this.maxPredictedClasses = maxPredictedClasses;
}
public ClassificationUpdateProcessorFactory.Algorithm getAlgorithm() {
return algorithm;
}
public void setAlgorithm(ClassificationUpdateProcessorFactory.Algorithm algorithm) {
this.algorithm = algorithm;
}
public int getMinTf() {
return minTf;
}
public void setMinTf(int minTf) {
this.minTf = minTf;
}
public int getMinDf() {
return minDf;
}
public void setMinDf(int minDf) {
this.minDf = minDf;
}
public int getK() {
return k;
}
public void setK(int k) {
this.k = k;
}
}

View File

@@ -0,0 +1,64 @@
<?xml version="1.0" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<!-- Test schema file for UnifiedSolrHighlighter -->
<schema name="unifiedhighlight" version="1.0">
<fieldType name="int" class="solr.TrieIntField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
<!-- basic text field: no offsets! -->
<fieldType name="text" class="solr.TextField">
<analyzer>
<tokenizer class="solr.MockTokenizerFactory"/>
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
</fieldType>
<!-- text field with offsets -->
<fieldType name="text_offsets" class="solr.TextField" storeOffsetsWithPositions="true">
<analyzer>
<tokenizer class="solr.MockTokenizerFactory"/>
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
</fieldType>
<field name="id" type="int" indexed="true" stored="true" multiValued="false" required="false"/>
<field name="text" type="text_offsets" indexed="true" stored="true"/>
<field name="text2" type="text" indexed="true" stored="true"/>
<field name="text3" type="text_offsets" indexed="true" stored="true"/>
<defaultSearchField>text</defaultSearchField>
<uniqueKey>id</uniqueKey>
</schema>

View File

@@ -47,6 +47,21 @@
<str name="knn.minTf">1</str>
<str name="knn.minDf">1</str>
<str name="knn.k">5</str>
<str name="knn.filterQuery">cat:(class1 OR class2)</str>
</processor>
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>
<updateRequestProcessorChain name="classification-unsupported-filterQuery">
<processor class="solr.ClassificationUpdateProcessorFactory">
<str name="inputFields">title,content,author</str>
<str name="classField">cat</str>
<!-- Knn algorithm specific-->
<str name="algorithm">knn</str>
<str name="knn.minTf">1</str>
<str name="knn.minDf">1</str>
<str name="knn.k">5</str>
<str name="knn.filterQuery">not valid ( lucene query</str>
</processor>
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>

View File

@@ -70,7 +70,12 @@ public class FastVectorHighlighterTest extends SolrTestCaseJ4 {
args.put("hl", "true");
args.put("hl.fl", "tv_text");
args.put("hl.snippets", "2");
args.put("hl.useFastVectorHighlighter", "true");
args.put("hl.tag.pre", "<fvpre>"); //... and let post default to </em>. This is just a test.
if (random().nextBoolean()) {
args.put("hl.useFastVectorHighlighter", "true"); // old way
} else {
args.put("hl.method", "fastVector"); // the new way
}
TestHarness.LocalRequestFactory sumLRF = h.getRequestFactory(
"standard",0,200,args);
@@ -81,7 +86,7 @@ public class FastVectorHighlighterTest extends SolrTestCaseJ4 {
assertQ("Basic summarization",
sumLRF.makeRequest("tv_text:vector"),
"//lst[@name='highlighting']/lst[@name='1']",
"//lst[@name='1']/arr[@name='tv_text']/str[.='basic fast <em>vector</em> highlighter test']"
"//lst[@name='1']/arr[@name='tv_text']/str[.='basic fast <fvpre>vector</em> highlighter test']"
);
}
}

View File

@@ -43,10 +43,6 @@ import org.junit.After;
import org.junit.BeforeClass;
import org.junit.Test;
/**
* Tests some basic functionality of Solr while demonstrating good
* Best Practices for using AbstractSolrTestCase
*/
public class HighlighterTest extends SolrTestCaseJ4 {
private static String LONG_TEXT = "a long days night this should be a piece of text which is is is is is is is is is is is is is is is is is is is " +
@@ -90,6 +86,25 @@ public class HighlighterTest extends SolrTestCaseJ4 {
assertTrue(regex instanceof RegexFragmenter);
}
@Test
public void testMethodPostings() {
String field = "t_text";
assertU(adoc(field, LONG_TEXT,
"id", "1"));
assertU(commit());
try {
assertQ("Tried PostingsSolrHighlighter but failed due to offsets not in postings",
req("q", "long", "hl.method", "postings", "df", field, "hl", "true"));
fail("Did not encounter exception for no offsets");
} catch (Exception e) {
assertTrue("Cause should be illegal argument", e.getCause() instanceof IllegalArgumentException);
assertTrue("Should warn no offsets", e.getCause().getMessage().contains("indexed without offsets"));
}
// note: the default schema.xml has no offsets in postings to test the PostingsHighlighter. Leave that for another
// test class.
}
@Test
public void testMergeContiguous() throws Exception {
HashMap<String,String> args = new HashMap<>();
@@ -99,6 +114,7 @@ public class HighlighterTest extends SolrTestCaseJ4 {
args.put(HighlightParams.SNIPPETS, String.valueOf(4));
args.put(HighlightParams.FRAGSIZE, String.valueOf(40));
args.put(HighlightParams.MERGE_CONTIGUOUS_FRAGMENTS, "true");
args.put(HighlightParams.METHOD, "original"); // test works; no complaints
TestHarness.LocalRequestFactory sumLRF = h.getRequestFactory(
"standard", 0, 200, args);
String input = "this is some long text. It has the word long in many places. In fact, it has long on some different fragments. " +
@@ -763,7 +779,7 @@ public class HighlighterTest extends SolrTestCaseJ4 {
);
// Prove fallback highlighting works also with FVH
args.put("hl.useFastVectorHighlighter", "true");
args.put("hl.method", "fastVector");
args.put("hl.tag.pre", "<fvhpre>");
args.put("hl.tag.post", "</fvhpost>");
args.put("f.t_text.hl.maxAlternateFieldLength", "18");

View File

@@ -52,7 +52,7 @@ public class TestPostingsSolrHighlighter extends SolrTestCaseJ4 {
public void testSimple() {
assertQ("simplest test",
req("q", "text:document", "sort", "id asc", "hl", "true"),
req("q", "text:document", "sort", "id asc", "hl", "true", "hl.method", "postings"), // test hl.method is happy too
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='<em>document</em> one'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second <em>document</em>'");

View File

@@ -0,0 +1,229 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.highlight;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.schema.IndexSchema;
import org.junit.BeforeClass;
/** Tests for the UnifiedHighlighter Solr plugin **/
public class TestUnifiedSolrHighlighter extends SolrTestCaseJ4 {
@BeforeClass
public static void beforeClass() throws Exception {
initCore("solrconfig-basic.xml", "schema-unifiedhighlight.xml");
// test our config is sane, just to be sure:
// 'text' and 'text3' should have offsets, 'text2' should not
IndexSchema schema = h.getCore().getLatestSchema();
assertTrue(schema.getField("text").storeOffsetsWithPositions());
assertTrue(schema.getField("text3").storeOffsetsWithPositions());
assertFalse(schema.getField("text2").storeOffsetsWithPositions());
}
@Override
public void setUp() throws Exception {
super.setUp();
clearIndex();
assertU(adoc("text", "document one", "text2", "document one", "text3", "crappy document", "id", "101"));
assertU(adoc("text", "second document", "text2", "second document", "text3", "crappier document", "id", "102"));
assertU(commit());
}
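// note: shadows SolrTestCaseJ4.req(String...) so that every request built in this class adds hl.method=unified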
public static SolrQueryRequest req(String... params) {
return SolrTestCaseJ4.req(params, "hl.method", "unified");
}
public void testSimple() {
assertQ("simplest test",
req("q", "text:document", "sort", "id asc", "hl", "true"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='<em>document</em> one'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second <em>document</em>'");
}
public void testImpossibleOffsetSource() {
try {
assertQ("impossible offset source",
req("q", "text2:document", "hl.offsetSource", "postings", "hl.fl", "text2", "sort", "id asc", "hl", "true"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='<em>document</em> one'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second <em>document</em>'");
fail("Did not encounter exception for no offsets");
} catch (Exception e) {
assertTrue("Cause should be illegal argument", e.getCause() instanceof IllegalArgumentException);
assertTrue("Should warn no offsets", e.getCause().getMessage().contains("indexed without offsets"));
}
}
public void testMultipleSnippetsReturned() {
clearIndex();
assertU(adoc("text", "Document snippet one. Intermediate sentence. Document snippet two.",
"text2", "document one", "text3", "crappy document", "id", "101"));
assertU(commit());
assertQ("multiple snippets test",
req("q", "text:document", "sort", "id asc", "hl", "true", "hl.snippets", "2", "hl.bs.type", "SENTENCE"),
"count(//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr/str[1]='<em>Document</em> snippet one. '",
"//lst[@name='highlighting']/lst[@name='101']/arr/str[2]='<em>Document</em> snippet two.'");
}
public void testStrictPhrasesEnabledByDefault() {
clearIndex();
assertU(adoc("text", "Strict phrases should be enabled for phrases",
"text2", "document one", "text3", "crappy document", "id", "101"));
assertU(commit());
assertQ("strict phrase handling",
req("q", "text:\"strict phrases\"", "sort", "id asc", "hl", "true"),
"count(//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/*)=1",
"//lst[@name='highlighting']/lst[@name='101']/arr/str[1]='<em>Strict</em> <em>phrases</em> should be enabled for phrases'");
}
public void testStrictPhrasesCanBeDisabled() {
clearIndex();
assertU(adoc("text", "Strict phrases should be disabled for phrases",
"text2", "document one", "text3", "crappy document", "id", "101"));
assertU(commit());
assertQ("strict phrase handling",
req("q", "text:\"strict phrases\"", "sort", "id asc", "hl", "true", "hl.usePhraseHighlighter", "false"),
"count(//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/*)=1",
"//lst[@name='highlighting']/lst[@name='101']/arr/str[1]='<em>Strict</em> <em>phrases</em> should be disabled for <em>phrases</em>'");
}
public void testMultiTermQueryEnabledByDefault() {
clearIndex();
assertU(adoc("text", "Aviary Avenue document",
"text2", "document one", "text3", "crappy document", "id", "101"));
assertU(commit());
assertQ("multi term query handling",
req("q", "text:av*", "sort", "id asc", "hl", "true"),
"count(//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/*)=1",
"//lst[@name='highlighting']/lst[@name='101']/arr/str[1]='<em>Aviary</em> <em>Avenue</em> document'");
}
public void testMultiTermQueryCanBeDisabled() {
clearIndex();
assertU(adoc("text", "Aviary Avenue document",
"text2", "document one", "text3", "crappy document", "id", "101"));
assertU(commit());
assertQ("multi term query handling",
req("q", "text:av*", "sort", "id asc", "hl", "true", "hl.highlightMultiTerm", "false"),
"count(//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/*)=0");
}
public void testPagination() {
assertQ("pagination test",
req("q", "text:document", "sort", "id asc", "hl", "true", "rows", "1", "start", "1"),
"count(//lst[@name='highlighting']/*)=1",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second <em>document</em>'");
}
public void testEmptySnippet() {
assertQ("null snippet test",
req("q", "text:one OR *:*", "sort", "id asc", "hl", "true"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='document <em>one</em>'",
"count(//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/*)=0");
}
public void testDefaultSummary() {
assertQ("null snippet test",
req("q", "text:one OR *:*", "sort", "id asc", "hl", "true", "hl.defaultSummary", "true"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='document <em>one</em>'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second document'");
}
public void testDifferentField() {
assertQ("highlighting text3",
req("q", "text3:document", "sort", "id asc", "hl", "true", "hl.fl", "text3"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text3']/str='crappy <em>document</em>'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text3']/str='crappier <em>document</em>'");
}
public void testTwoFields() {
assertQ("highlighting text and text3",
req("q", "text:document text3:document", "sort", "id asc", "hl", "true", "hl.fl", "text,text3"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='<em>document</em> one'",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text3']/str='crappy <em>document</em>'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second <em>document</em>'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text3']/str='crappier <em>document</em>'");
}
public void testTags() {
assertQ("different pre/post tags",
req("q", "text:document", "sort", "id asc", "hl", "true", "hl.tag.pre", "[", "hl.tag.post", "]"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='[document] one'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second [document]'");
}
public void testUsingSimplePrePostTags() {
assertQ("different pre/post tags",
req("q", "text:document", "sort", "id asc", "hl", "true", "hl.simple.pre", "[", "hl.simple.post", "]"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='[document] one'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second [document]'");
}
public void testUsingSimplePrePostTagsPerField() {
assertQ("different pre/post tags",
req("q", "text:document", "sort", "id asc", "hl", "true", "f.text.hl.simple.pre", "[", "f.text.hl.simple.post", "]"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='[document] one'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second [document]'");
}
public void testTagsPerField() {
assertQ("highlighting text and text3",
req("q", "text:document text3:document", "sort", "id asc", "hl", "true", "hl.fl", "text,text3", "f.text3.hl.tag.pre", "[", "f.text3.hl.tag.post", "]"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='<em>document</em> one'",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text3']/str='crappy [document]'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='second <em>document</em>'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text3']/str='crappier [document]'");
}
public void testBreakIterator() {
assertQ("different breakiterator",
req("q", "text:document", "sort", "id asc", "hl", "true", "hl.bs.type", "WORD"),
"count(//lst[@name='highlighting']/*)=2",
"//lst[@name='highlighting']/lst[@name='101']/arr[@name='text']/str='<em>document</em>'",
"//lst[@name='highlighting']/lst[@name='102']/arr[@name='text']/str='<em>document</em>'");
}
public void testBreakIterator2() {
assertU(adoc("text", "Document one has a first sentence. Document two has a second sentence.", "id", "103"));
assertU(commit());
assertQ("different breakiterator",
req("q", "text:document", "sort", "id asc", "hl", "true", "hl.bs.type", "WHOLE"),
"//lst[@name='highlighting']/lst[@name='103']/arr[@name='text']/str='<em>Document</em> one has a first sentence. <em>Document</em> two has a second sentence.'");
}
public void testEncoder() {
assertU(adoc("text", "Document one has a first <i>sentence</i>.", "id", "103"));
assertU(commit());
assertQ("html escaped",
req("q", "text:document", "sort", "id asc", "hl", "true", "hl.encoder", "html"),
"//lst[@name='highlighting']/lst[@name='103']/arr[@name='text']/str='<em>Document</em>&#32;one&#32;has&#32;a&#32;first&#32;&lt;i&gt;sentence&lt;&#x2F;i&gt;&#46;'");
}
}
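
A minimal SolrJ sketch of what these tests exercise (assumptions: a running core with the schema above and a SolrClient named "client"; imports from org.apache.solr.client.solrj and org.apache.solr.common.params):

SolrQuery query = new SolrQuery("text:document");
query.setHighlight(true);                          // hl=true
query.set(HighlightParams.METHOD, "unified");      // select the UnifiedHighlighter adapter
query.set(HighlightParams.FIELDS, "text");         // hl.fl
query.set(HighlightParams.SNIPPETS, 2);            // hl.snippets
QueryResponse rsp = client.query(query);
List<String> snippets = rsp.getHighlighting().get("101").get("text"); // snippets for doc 101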

View File

@ -14,71 +14,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.update.processor;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.search.SolrIndexSearcher;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
* Tests for {@link ClassificationUpdateProcessorFactory}
*/
public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
// field names are used in accordance with the solrconfig and schema supplied
private static final String ID = "id";
private static final String TITLE = "title";
private static final String CONTENT = "content";
private static final String AUTHOR = "author";
private static final String CLASS = "cat";
private static final String CHAIN = "classification";
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
private NamedList args = new NamedList<String>();
@BeforeClass
public static void beforeClass() throws Exception {
System.setProperty("enable.update.log", "false");
initCore("solrconfig-classification.xml", "schema-classification.xml");
}
@Override
@Before
public void setUp() throws Exception {
super.setUp();
clearIndex();
assertU(commit());
}
@Before
public void initArgs() {
args.add("inputFields", "inputField1,inputField2");
args.add("classField", "classField1");
args.add("predictedClassField", "classFieldX");
args.add("algorithm", "bayes");
args.add("knn.k", "9");
args.add("knn.minDf", "8");
@ -86,22 +46,23 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
public void testFullInit() {
public void init_fullArgs_shouldInitFullClassificationParams() {
cFactoryToTest.init(args);
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
String[] inputFieldNames = classificationParams.getInputFieldNames();
assertEquals("inputField1", inputFieldNames[0]);
assertEquals("inputField2", inputFieldNames[1]);
assertEquals("classField1", cFactoryToTest.getClassFieldName());
assertEquals("bayes", cFactoryToTest.getAlgorithm());
assertEquals(8, cFactoryToTest.getMinDf());
assertEquals(10, cFactoryToTest.getMinTf());
assertEquals(9, cFactoryToTest.getK());
assertEquals("classField1", classificationParams.getTrainingClassField());
assertEquals("classFieldX", classificationParams.getPredictedClassField());
assertEquals(ClassificationUpdateProcessorFactory.Algorithm.BAYES, classificationParams.getAlgorithm());
assertEquals(8, classificationParams.getMinDf());
assertEquals(10, classificationParams.getMinTf());
assertEquals(9, classificationParams.getK());
}
@Test
public void testInitEmptyInputField() {
public void init_emptyInputFields_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("inputFields");
try {
cFactoryToTest.init(args);
@ -111,7 +72,7 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
public void testInitEmptyClassField() {
public void init_emptyClassField_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("classField");
try {
cFactoryToTest.init(args);
@ -121,114 +82,53 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
public void testDefaults() {
public void init_emptyPredictedClassField_shouldDefaultToTrainingClassField() {
args.removeAll("predictedClassField");
cFactoryToTest.init(args);
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
assertThat(classificationParams.getPredictedClassField(), is("classField1"));
}
@Test
public void init_unsupportedAlgorithm_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("algorithm");
args.add("algorithm", "unsupported");
try {
cFactoryToTest.init(args);
} catch (SolrException e) {
assertEquals("Classification UpdateProcessor Algorithm: 'unsupported' not supported", e.getMessage());
}
}
@Test
public void init_unsupportedFilterQuery_shouldThrowExceptionWithDetailedMessage() {
UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
SolrQueryRequest mockRequest = mock(SolrQueryRequest.class);
SolrQueryResponse mockResponse = mock(SolrQueryResponse.class);
args.add("knn.filterQuery", "not supported query");
try {
cFactoryToTest.init(args);
/* parsing fails because of the mocks; that is sufficient to verify proper exception propagation */
cFactoryToTest.getInstance(mockRequest, mockResponse, mockProcessor);
} catch (SolrException e) {
assertEquals("Classification UpdateProcessor Training Filter Query: 'not supported query' is not supported", e.getMessage());
}
}
@Test
public void init_emptyArgs_shouldDefaultClassificationParams() {
args.removeAll("algorithm");
args.removeAll("knn.k");
args.removeAll("knn.minDf");
args.removeAll("knn.minTf");
cFactoryToTest.init(args);
assertEquals("knn", cFactoryToTest.getAlgorithm());
assertEquals(1, cFactoryToTest.getMinDf());
assertEquals(1, cFactoryToTest.getMinTf());
assertEquals(10, cFactoryToTest.getK());
}
ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
@Test
public void testBasicClassification() throws Exception {
prepareTrainedIndex();
// To be classified, we index documents without a class and verify the expected one is returned
addDoc(adoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 ",
AUTHOR, "Name1 Surname1"));
addDoc(adoc(ID, "11",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname"));
addDoc(commit());
Document doc10 = getDoc("10");
assertEquals("class2", doc10.get(CLASS));
Document doc11 = getDoc("11");
assertEquals("class1", doc11.get(CLASS));
}
/**
* Index some example documents with a class manually assigned.
* This will be our trained model.
*
* @throws Exception If there is a low-level I/O error
*/
private void prepareTrainedIndex() throws Exception {
//class1
addDoc(adoc(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
addDoc(adoc(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
addDoc(adoc(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
addDoc(adoc(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"));
//class2
addDoc(adoc(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(adoc(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(adoc(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(adoc(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class2"));
addDoc(commit());
}
private Document getDoc(String id) throws IOException {
try (SolrQueryRequest req = req()) {
SolrIndexSearcher searcher = req.getSearcher();
TermQuery query = new TermQuery(new Term(ID, id));
TopDocs doc1 = searcher.search(query, 1);
ScoreDoc scoreDoc = doc1.scoreDocs[0];
return searcher.doc(scoreDoc.doc);
}
}
static void addDoc(String doc) throws Exception {
Map<String, String[]> params = new HashMap<>();
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
(SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
req.close();
assertEquals(ClassificationUpdateProcessorFactory.Algorithm.KNN, classificationParams.getAlgorithm());
assertEquals(1, classificationParams.getMinDf());
assertEquals(1, classificationParams.getMinTf());
assertEquals(10, classificationParams.getK());
}
}
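
A hedged sketch of the factory's init surface that these tests cover (values illustrative, mirroring the NamedList arguments built in initArgs above):

NamedList<String> initArgs = new NamedList<>();
initArgs.add("inputFields", "title,content");      // comma-separated input fields
initArgs.add("classField", "cat");                 // training class field
initArgs.add("algorithm", "knn");                  // or "bayes"
initArgs.add("knn.k", "10");
ClassificationUpdateProcessorFactory factory = new ClassificationUpdateProcessorFactory();
factory.init(initArgs);                            // throws SolrException if mandatory args are missing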

View File

@ -0,0 +1,192 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.update.processor;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.search.SolrIndexSearcher;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.hamcrest.core.Is.is;
/**
* Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
*/
public class ClassificationUpdateProcessorIntegrationTest extends SolrTestCaseJ4 {
/* field names are used in accordance with the solrconfig and schema supplied */
private static final String ID = "id";
private static final String TITLE = "title";
private static final String CONTENT = "content";
private static final String AUTHOR = "author";
private static final String CLASS = "cat";
private static final String CHAIN = "classification";
private static final String BROKEN_CHAIN_FILTER_QUERY = "classification-unsupported-filterQuery";
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
private NamedList args = new NamedList<String>();
@BeforeClass
public static void beforeClass() throws Exception {
System.setProperty("enable.update.log", "false");
initCore("solrconfig-classification.xml", "schema-classification.xml");
}
@Override
@Before
public void setUp() throws Exception {
super.setUp();
clearIndex();
assertU(commit());
}
@Test
public void classify_fullConfiguration_shouldAutoClassify() throws Exception {
indexTrainingSet();
// To be classified, we index documents without a class and verify the expected one is returned
addDoc(adoc(ID, "22",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 ",
AUTHOR, "Name1 Surname1"), CHAIN);
addDoc(adoc(ID, "21",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname"), CHAIN);
addDoc(commit());
Document doc22 = getDoc("22");
assertThat(doc22.get(CLASS),is("class2"));
Document doc21 = getDoc("21");
assertThat(doc21.get(CLASS),is("class1"));
}
@Test
public void classify_unsupportedFilterQueryConfiguration_shouldThrowExceptionWithDetailedMessage() throws Exception {
indexTrainingSet();
try {
addDoc(adoc(ID, "21",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 ",
AUTHOR, "Name1 Surname1"), BROKEN_CHAIN_FILTER_QUERY);
addDoc(adoc(ID, "22",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname"), BROKEN_CHAIN_FILTER_QUERY);
addDoc(commit());
} catch (SolrException e) {
assertEquals("Classification UpdateProcessor Training Filter Query: 'not valid ( lucene query' is not supported", e.getMessage());
}
}
/**
* Index some example documents with a class manually assigned.
* This will be our trained model.
*
* @throws Exception If there is a low-level I/O error
*/
private void indexTrainingSet() throws Exception {
//class1
addDoc(adoc(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
addDoc(adoc(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
addDoc(adoc(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
addDoc(adoc(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
CLASS, "class1"), CHAIN);
//class2
addDoc(adoc(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
addDoc(adoc(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
addDoc(adoc(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
addDoc(adoc(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name Surname",
CLASS, "class2"), CHAIN);
//class3
addDoc(adoc(ID, "9",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(adoc(ID, "10",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(adoc(ID, "11",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(adoc(ID, "12",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name1 Surname1",
CLASS, "class3"), CHAIN);
addDoc(commit());
}
private Document getDoc(String id) throws IOException {
try (SolrQueryRequest req = req()) {
SolrIndexSearcher searcher = req.getSearcher();
TermQuery query = new TermQuery(new Term(ID, id));
TopDocs doc1 = searcher.search(query, 1);
ScoreDoc scoreDoc = doc1.scoreDocs[0];
return searcher.doc(scoreDoc.doc);
}
}
private void addDoc(String doc) throws Exception {
addDoc(doc, CHAIN);
}
}
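
The same flow from a client's point of view, as a hedged sketch (assumption: a SolrClient "client" for a core that defines the "classification" update chain):

UpdateRequest ureq = new UpdateRequest();
ureq.setParam(UpdateParams.UPDATE_CHAIN, "classification");
SolrInputDocument doc = new SolrInputDocument();
doc.addField("id", "42");
doc.addField("title", "word4 word4 word4");        // no "cat" value supplied
ureq.add(doc);
ureq.process(client);                              // the chain predicts and fills the "cat" field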

View File

@ -0,0 +1,506 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.update.processor;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.Directory;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.update.AddUpdateCommand;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link ClassificationUpdateProcessor}
*/
public class ClassificationUpdateProcessorTest extends SolrTestCaseJ4 {
/* field names are used in accordance with the solrconfig and schema supplied */
private static final String ID = "id";
private static final String TITLE = "title";
private static final String CONTENT = "content";
private static final String AUTHOR = "author";
private static final String TRAINING_CLASS = "cat";
private static final String PREDICTED_CLASS = "predicted";
public static final String KNN = "knn";
protected Directory directory;
protected IndexReader reader;
protected IndexSearcher searcher;
protected Analyzer analyzer = new MockAnalyzer(random(), MockTokenizer.WHITESPACE, false);
private ClassificationUpdateProcessor updateProcessorToTest;
@BeforeClass
public static void beforeClass() throws Exception {
System.setProperty("enable.update.log", "false");
initCore("solrconfig-classification.xml", "schema-classification.xml");
}
@Override
public void setUp() throws Exception {
super.setUp();
}
@Override
public void tearDown() throws Exception {
reader.close();
directory.close();
analyzer.close();
super.tearDown();
}
@Test
public void classificationMonoClass_predictedClassFieldSet_shouldAssignClassInPredictedClassField() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setPredictedClassField(PREDICTED_CLASS);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(PREDICTED_CLASS),is("class1"));
}
@Test
public void knnMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
}
@Test
public void knnMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
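// Input fields may carry a "field^boost" suffix; the boost weights that field's terms in the classification.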
params.setInputFieldNames(new String[]{TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"});
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
}
@Test
public void bayesMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
}
@Test
public void knnMonoClass_contextQueryFiltered_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "a");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
Query class3DocsChunk=new TermQuery(new Term(TITLE,"word6"));
params.setTrainingFilterQuery(class3DocsChunk);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class3"));
}
@Test
public void bayesMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMonoClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
}
@Test
public void knnClassification_maxOutputClassesGreaterThanAvailable_shouldAssignCorrectClass() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setMaxPredictedClasses(100);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.get(0),is("class2"));
assertThat(assignedClasses.get(1),is("class1"));
}
@Test
public void knnMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class2"));
assertThat(assignedClasses.get(1),is("class1"));
}
@Test
public void bayesMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class2"));
assertThat(assignedClasses.get(1),is("class1"));
}
@Test
public void knnMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class4"));
assertThat(assignedClasses.get(1),is("class6"));
}
@Test
public void bayesMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
prepareTrainedIndexMultiClass();
AddUpdateCommand update=new AddUpdateCommand(req());
SolrInputDocument unseenDocument1 = sdoc(ID, "10",
TITLE, "word4 word4 word4",
CONTENT, "word2 word2 ",
AUTHOR, "unseenAuthor");
update.solrDoc=unseenDocument1;
ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
params.setMaxPredictedClasses(2);
updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
updateProcessorToTest.processAdd(update);
ArrayList<Object> assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
assertThat(assignedClasses.size(),is(2));
assertThat(assignedClasses.get(0),is("class4"));
assertThat(assignedClasses.get(1),is("class6"));
}
private ClassificationUpdateProcessorParams initParams(ClassificationUpdateProcessorFactory.Algorithm classificationAlgorithm) {
ClassificationUpdateProcessorParams params= new ClassificationUpdateProcessorParams();
params.setInputFieldNames(new String[]{TITLE,CONTENT,AUTHOR});
params.setTrainingClassField(TRAINING_CLASS);
params.setPredictedClassField(TRAINING_CLASS);
params.setMinTf(1);
params.setMinDf(1);
params.setK(5);
params.setAlgorithm(classificationAlgorithm);
params.setMaxPredictedClasses(1);
return params;
}
/**
* Index some example documents with a class manually assigned.
* This will be our trained model.
*
* @throws Exception If there is a low-level I/O error
*/
private void prepareTrainedIndexMonoClass() throws Exception {
directory = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
//class1
addDoc(writer, buildLuceneDocument(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
addDoc(writer, buildLuceneDocument(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
addDoc(writer, buildLuceneDocument(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
addDoc(writer, buildLuceneDocument(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "a",
TRAINING_CLASS, "class1"));
//class2
addDoc(writer, buildLuceneDocument(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
addDoc(writer, buildLuceneDocument(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
addDoc(writer, buildLuceneDocument(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
addDoc(writer, buildLuceneDocument(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "c",
TRAINING_CLASS, "class2"));
//class3
addDoc(writer, buildLuceneDocument(ID, "9",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
addDoc(writer, buildLuceneDocument(ID, "10",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
addDoc(writer, buildLuceneDocument(ID, "11",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
addDoc(writer, buildLuceneDocument(ID, "12",
TITLE, "word6",
CONTENT, "word7",
AUTHOR, "a",
TRAINING_CLASS, "class3"));
reader = writer.getReader();
writer.close();
searcher = newSearcher(reader);
}
private void prepareTrainedIndexMultiClass() throws Exception {
directory = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
//class1
addDoc(writer, buildLuceneDocument(ID, "1",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class1",
TRAINING_CLASS, "class2"
));
addDoc(writer, buildLuceneDocument(ID, "2",
TITLE, "word1 word1",
CONTENT, "word2 word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class3",
TRAINING_CLASS, "class2"
));
addDoc(writer, buildLuceneDocument(ID, "3",
TITLE, "word1 word1 word1",
CONTENT, "word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class1",
TRAINING_CLASS, "class2"
));
addDoc(writer, buildLuceneDocument(ID, "4",
TITLE, "word1 word1 word1",
CONTENT, "word2 word2 word2",
AUTHOR, "Name Surname",
TRAINING_CLASS, "class1",
TRAINING_CLASS, "class2"
));
//class2
addDoc(writer, buildLuceneDocument(ID, "5",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class6",
TRAINING_CLASS, "class4"
));
addDoc(writer, buildLuceneDocument(ID, "6",
TITLE, "word4 word4",
CONTENT, "word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class5",
TRAINING_CLASS, "class4"
));
addDoc(writer, buildLuceneDocument(ID, "7",
TITLE, "word4 word4 word4",
CONTENT, "word5 word5 word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class6",
TRAINING_CLASS, "class4"
));
addDoc(writer, buildLuceneDocument(ID, "8",
TITLE, "word4",
CONTENT, "word5 word5 word5 word5",
AUTHOR, "Name1 Surname1",
TRAINING_CLASS, "class6",
TRAINING_CLASS, "class4"
));
reader = writer.getReader();
writer.close();
searcher = newSearcher(reader);
}
public static Document buildLuceneDocument(Object... fieldsAndValues) {
Document luceneDoc = new Document();
for (int i=0; i<fieldsAndValues.length; i+=2) {
luceneDoc.add(newTextField((String)fieldsAndValues[i], (String)fieldsAndValues[i+1], Field.Store.YES));
}
return luceneDoc;
}
private int addDoc(RandomIndexWriter writer, Document doc) throws IOException {
writer.addDocument(doc);
return writer.numDocs() - 1;
}
}

View File

@ -16,31 +16,28 @@
*/
package org.apache.solr.update.processor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.util.Constants;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.client.solrj.impl.BinaryRequestWriter;
import org.apache.solr.client.solrj.request.UpdateRequest;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
*
*/
@ -359,21 +356,4 @@ public class SignatureUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
private void addDoc(String doc) throws Exception {
addDoc(doc, chain);
}
static void addDoc(String doc, String chain) throws Exception {
Map<String, String[]> params = new HashMap<>();
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
(SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
req.close();
}
}

View File

@ -25,8 +25,6 @@ import org.junit.Test;
import java.util.Map;
import static org.apache.solr.update.processor.SignatureUpdateProcessorFactoryTest.addDoc;
public class TestPartialUpdateDeduplication extends SolrTestCaseJ4 {
@BeforeClass
public static void beforeClass() throws Exception {

View File

@ -21,62 +21,76 @@ package org.apache.solr.common.params;
* @since solr 1.3
*/
public interface HighlightParams {
// primary
public static final String HIGHLIGHT = "hl";
public static final String Q = HIGHLIGHT+".q";
public static final String QPARSER = HIGHLIGHT+".qparser";
public static final String METHOD = HIGHLIGHT+".method"; // original|fastVector|postings|unified
@Deprecated // see hl.method
public static final String USE_FVH = HIGHLIGHT + ".useFastVectorHighlighter";
public static final String FIELDS = HIGHLIGHT+".fl";
public static final String SNIPPETS = HIGHLIGHT+".snippets";
public static final String FRAGSIZE = HIGHLIGHT+".fragsize";
public static final String INCREMENT = HIGHLIGHT+".increment";
public static final String MAX_CHARS = HIGHLIGHT+".maxAnalyzedChars";
public static final String FORMATTER = HIGHLIGHT+".formatter";
public static final String ENCODER = HIGHLIGHT+".encoder";
public static final String FRAGMENTER = HIGHLIGHT+".fragmenter";
public static final String PRESERVE_MULTI = HIGHLIGHT+".preserveMulti";
public static final String FRAG_LIST_BUILDER = HIGHLIGHT+".fragListBuilder";
public static final String FRAGMENTS_BUILDER = HIGHLIGHT+".fragmentsBuilder";
public static final String BOUNDARY_SCANNER = HIGHLIGHT+".boundaryScanner";
public static final String BS_MAX_SCAN = HIGHLIGHT+".bs.maxScan";
public static final String BS_CHARS = HIGHLIGHT+".bs.chars";
public static final String BS_TYPE = HIGHLIGHT+".bs.type";
public static final String BS_LANGUAGE = HIGHLIGHT+".bs.language";
public static final String BS_COUNTRY = HIGHLIGHT+".bs.country";
public static final String BS_VARIANT = HIGHLIGHT+".bs.variant";
public static final String FIELD_MATCH = HIGHLIGHT+".requireFieldMatch";
public static final String DEFAULT_SUMMARY = HIGHLIGHT + ".defaultSummary";
public static final String ALTERNATE_FIELD = HIGHLIGHT+".alternateField";
public static final String ALTERNATE_FIELD_LENGTH = HIGHLIGHT+".maxAlternateFieldLength";
public static final String HIGHLIGHT_ALTERNATE = HIGHLIGHT+".highlightAlternate";
public static final String MAX_MULTIVALUED_TO_EXAMINE = HIGHLIGHT + ".maxMultiValuedToExamine";
public static final String MAX_MULTIVALUED_TO_MATCH = HIGHLIGHT + ".maxMultiValuedToMatch";
public static final String USE_PHRASE_HIGHLIGHTER = HIGHLIGHT+".usePhraseHighlighter";
public static final String HIGHLIGHT_MULTI_TERM = HIGHLIGHT+".highlightMultiTerm";
public static final String PAYLOADS = HIGHLIGHT+".payloads";
public static final String MERGE_CONTIGUOUS_FRAGMENTS = HIGHLIGHT + ".mergeContiguous";
// KEY:
// OH = (original) Highlighter (AKA the standard Highlighter)
// FVH = FastVectorHighlighter
// PH = PostingsHighlighter
// UH = UnifiedHighlighter
public static final String USE_FVH = HIGHLIGHT + ".useFastVectorHighlighter";
public static final String TAG_PRE = HIGHLIGHT + ".tag.pre";
public static final String TAG_POST = HIGHLIGHT + ".tag.post";
public static final String TAG_ELLIPSIS = HIGHLIGHT + ".tag.ellipsis";
public static final String PHRASE_LIMIT = HIGHLIGHT + ".phraseLimit";
public static final String MULTI_VALUED_SEPARATOR = HIGHLIGHT + ".multiValuedSeparatorChar";
// Formatter
public static final String SIMPLE = "simple";
public static final String SIMPLE_PRE = HIGHLIGHT+"."+SIMPLE+".pre";
public static final String SIMPLE_POST = HIGHLIGHT+"."+SIMPLE+".post";
// query interpretation
public static final String Q = HIGHLIGHT+".q"; // all
public static final String QPARSER = HIGHLIGHT+".qparser"; // all
public static final String FIELD_MATCH = HIGHLIGHT+".requireFieldMatch"; // OH, FVH
public static final String USE_PHRASE_HIGHLIGHTER = HIGHLIGHT+".usePhraseHighlighter"; // OH, FVH, UH
public static final String HIGHLIGHT_MULTI_TERM = HIGHLIGHT+".highlightMultiTerm"; // all
// Regex fragmenter
public static final String REGEX = "regex";
public static final String SLOP = HIGHLIGHT+"."+REGEX+".slop";
public static final String PATTERN = HIGHLIGHT+"."+REGEX+".pattern";
public static final String MAX_RE_CHARS = HIGHLIGHT+"."+REGEX+".maxAnalyzedChars";
// Scoring parameters
public static final String SCORE = "score";
public static final String SCORE_K1 = HIGHLIGHT +"."+SCORE+".k1";
public static final String SCORE_B = HIGHLIGHT +"."+SCORE+".b";
public static final String SCORE_PIVOT = HIGHLIGHT +"."+SCORE+".pivot";
// if no snippets...
public static final String DEFAULT_SUMMARY = HIGHLIGHT + ".defaultSummary"; // UH, PH
public static final String ALTERNATE_FIELD = HIGHLIGHT+".alternateField"; // OH, FVH
public static final String ALTERNATE_FIELD_LENGTH = HIGHLIGHT+".maxAlternateFieldLength"; // OH, FVH
public static final String HIGHLIGHT_ALTERNATE = HIGHLIGHT+".highlightAlternate"; // OH, FVH
// sizing
public static final String FRAGSIZE = HIGHLIGHT+".fragsize"; // OH, FVH
public static final String FRAGMENTER = HIGHLIGHT+".fragmenter"; // OH
public static final String INCREMENT = HIGHLIGHT+".increment"; // OH
public static final String REGEX = "regex"; // OH
public static final String SLOP = HIGHLIGHT+"."+REGEX+".slop"; // OH
public static final String PATTERN = HIGHLIGHT+"."+REGEX+".pattern"; // OH
public static final String MAX_RE_CHARS = HIGHLIGHT+"."+REGEX+".maxAnalyzedChars"; // OH
public static final String BOUNDARY_SCANNER = HIGHLIGHT+".boundaryScanner"; // FVH
public static final String BS_MAX_SCAN = HIGHLIGHT+".bs.maxScan"; // FVH
public static final String BS_CHARS = HIGHLIGHT+".bs.chars"; // FVH
public static final String BS_TYPE = HIGHLIGHT+".bs.type"; // FVH, UH, PH
public static final String BS_LANGUAGE = HIGHLIGHT+".bs.language"; // FVH, UH, PH
public static final String BS_COUNTRY = HIGHLIGHT+".bs.country"; // FVH, UH, PH
public static final String BS_VARIANT = HIGHLIGHT+".bs.variant"; // FVH, UH, PH
// formatting
public static final String FORMATTER = HIGHLIGHT+".formatter"; // OH
public static final String ENCODER = HIGHLIGHT+".encoder"; // OH, (UH, PH limited)
public static final String MERGE_CONTIGUOUS_FRAGMENTS = HIGHLIGHT + ".mergeContiguous"; // OH
public static final String SIMPLE = "simple"; // OH
public static final String SIMPLE_PRE = HIGHLIGHT+"."+SIMPLE+".pre"; // OH
public static final String SIMPLE_POST = HIGHLIGHT+"."+SIMPLE+".post"; // OH
public static final String FRAGMENTS_BUILDER = HIGHLIGHT+".fragmentsBuilder"; // FVH
public static final String TAG_PRE = HIGHLIGHT + ".tag.pre"; // FVH, UH, PH
public static final String TAG_POST = HIGHLIGHT + ".tag.post"; // FVH, UH, PH
public static final String TAG_ELLIPSIS = HIGHLIGHT + ".tag.ellipsis"; // FVH, UH, PH
public static final String MULTI_VALUED_SEPARATOR = HIGHLIGHT + ".multiValuedSeparatorChar"; // FVH, PH
// ordering
public static final String PRESERVE_MULTI = HIGHLIGHT+".preserveMulti"; // OH
public static final String FRAG_LIST_BUILDER = HIGHLIGHT+".fragListBuilder"; // FVH
public static final String SCORE = "score"; // UH, PH
public static final String SCORE_K1 = HIGHLIGHT +"."+SCORE+".k1"; // UH, PH
public static final String SCORE_B = HIGHLIGHT +"."+SCORE+".b"; // UH, PH
public static final String SCORE_PIVOT = HIGHLIGHT +"."+SCORE+".pivot"; // UH, PH
// misc
public static final String MAX_CHARS = HIGHLIGHT+".maxAnalyzedChars"; // all
public static final String PAYLOADS = HIGHLIGHT+".payloads"; // OH
public static final String MAX_MULTIVALUED_TO_EXAMINE = HIGHLIGHT + ".maxMultiValuedToExamine"; // OH
public static final String MAX_MULTIVALUED_TO_MATCH = HIGHLIGHT + ".maxMultiValuedToMatch"; // OH
public static final String PHRASE_LIMIT = HIGHLIGHT + ".phraseLimit"; // FVH
public static final String OFFSET_SOURCE = HIGHLIGHT + ".offsetSource"; // UH
public static final String CACHE_FIELD_VAL_CHARS_THRESHOLD = HIGHLIGHT + ".cacheFieldValCharsThreshold"; // UH
}
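
An illustrative, non-normative snippet showing these constants in use (assumption: a client-side ModifiableSolrParams request):

ModifiableSolrParams params = new ModifiableSolrParams();
params.set(HighlightParams.HIGHLIGHT, true);       // hl=true
params.set(HighlightParams.METHOD, "unified");     // preferred over the deprecated hl.useFastVectorHighlighter
params.set(HighlightParams.FIELDS, "text");        // hl.fl
params.set(HighlightParams.BS_TYPE, "SENTENCE");   // hl.bs.type (FVH, UH, PH)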

View File

@ -83,7 +83,11 @@ import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.SolrInputField;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.ObjectReleaseTracker;
import org.apache.solr.common.util.XML;
import org.apache.solr.core.CoreContainer;
@ -96,7 +100,9 @@ import org.apache.solr.core.SolrXmlConfig;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.request.SolrRequestHandler;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.SolrIndexSearcher;
@ -1009,6 +1015,22 @@ public abstract class SolrTestCaseJ4 extends LuceneTestCase {
return out.toString();
}
public static void addDoc(String doc, String updateRequestProcessorChain) throws Exception {
Map<String, String[]> params = new HashMap<>();
MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
params.put(UpdateParams.UPDATE_CHAIN, new String[]{updateRequestProcessorChain});
SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
(SolrParams) mmparams) {
};
UpdateRequestHandler handler = new UpdateRequestHandler();
handler.init(null);
ArrayList<ContentStream> streams = new ArrayList<>(2);
streams.add(new ContentStreamBase.StringStream(doc));
req.setContentStreams(streams);
handler.handleRequestBody(req, new SolrQueryResponse());
req.close();
}
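// Usage sketch: route a raw XML document through a specific update chain, e.g.
//   addDoc(adoc("id", "1", "title", "hello"), "classification");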
/**
* Generates an &lt;add&gt;&lt;doc&gt;... XML String with options