mirror of
https://github.com/apache/lucene.git
synced 2025-02-24 11:16:35 +00:00
SOLR-10156: Add significantTerms Streaming Expression
This commit is contained in:
parent
894a43b259
commit
dba733e7aa
@ -51,41 +51,7 @@ import org.apache.solr.client.solrj.io.ops.ConcatOperation;
|
||||
import org.apache.solr.client.solrj.io.ops.DistinctOperation;
|
||||
import org.apache.solr.client.solrj.io.ops.GroupOperation;
|
||||
import org.apache.solr.client.solrj.io.ops.ReplaceOperation;
|
||||
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
|
||||
import org.apache.solr.client.solrj.io.stream.CommitStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ComplementStream;
|
||||
import org.apache.solr.client.solrj.io.stream.DaemonStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ExceptionStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ExecutorStream;
|
||||
import org.apache.solr.client.solrj.io.stream.FacetStream;
|
||||
import org.apache.solr.client.solrj.io.stream.FeaturesSelectionStream;
|
||||
import org.apache.solr.client.solrj.io.stream.FetchStream;
|
||||
import org.apache.solr.client.solrj.io.stream.HashJoinStream;
|
||||
import org.apache.solr.client.solrj.io.stream.HavingStream;
|
||||
import org.apache.solr.client.solrj.io.stream.InnerJoinStream;
|
||||
import org.apache.solr.client.solrj.io.stream.IntersectStream;
|
||||
import org.apache.solr.client.solrj.io.stream.JDBCStream;
|
||||
import org.apache.solr.client.solrj.io.stream.LeftOuterJoinStream;
|
||||
import org.apache.solr.client.solrj.io.stream.MergeStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ModelStream;
|
||||
import org.apache.solr.client.solrj.io.stream.NullStream;
|
||||
import org.apache.solr.client.solrj.io.stream.OuterHashJoinStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ParallelStream;
|
||||
import org.apache.solr.client.solrj.io.stream.PriorityStream;
|
||||
import org.apache.solr.client.solrj.io.stream.RandomStream;
|
||||
import org.apache.solr.client.solrj.io.stream.RankStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ReducerStream;
|
||||
import org.apache.solr.client.solrj.io.stream.RollupStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ScoreNodesStream;
|
||||
import org.apache.solr.client.solrj.io.stream.SelectStream;
|
||||
import org.apache.solr.client.solrj.io.stream.SortStream;
|
||||
import org.apache.solr.client.solrj.io.stream.StatsStream;
|
||||
import org.apache.solr.client.solrj.io.stream.StreamContext;
|
||||
import org.apache.solr.client.solrj.io.stream.TextLogitStream;
|
||||
import org.apache.solr.client.solrj.io.stream.TopicStream;
|
||||
import org.apache.solr.client.solrj.io.stream.TupleStream;
|
||||
import org.apache.solr.client.solrj.io.stream.UniqueStream;
|
||||
import org.apache.solr.client.solrj.io.stream.UpdateStream;
|
||||
import org.apache.solr.client.solrj.io.stream.*;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
|
||||
@ -193,7 +159,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
||||
.withFunctionName("executor", ExecutorStream.class)
|
||||
.withFunctionName("null", NullStream.class)
|
||||
.withFunctionName("priority", PriorityStream.class)
|
||||
|
||||
.withFunctionName("significantTerms", SignificantTermsStream.class)
|
||||
// metrics
|
||||
.withFunctionName("min", MinMetric.class)
|
||||
.withFunctionName("max", MaxMetric.class)
|
||||
|
@ -79,6 +79,8 @@ public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrI
|
||||
map.put(GraphTermsQParserPlugin.NAME, GraphTermsQParserPlugin.class);
|
||||
map.put(IGainTermsQParserPlugin.NAME, IGainTermsQParserPlugin.class);
|
||||
map.put(TextLogisticRegressionQParserPlugin.NAME, TextLogisticRegressionQParserPlugin.class);
|
||||
map.put(SignificantTermsQParserPlugin.NAME, SignificantTermsQParserPlugin.class);
|
||||
|
||||
standardPlugins = Collections.unmodifiableMap(map);
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,260 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.search;
|
||||
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.TreeSet;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.PostingsEnum;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.SparseFixedBitSet;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.handler.component.ResponseBuilder;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
|
||||
public class SignificantTermsQParserPlugin extends QParserPlugin {
|
||||
|
||||
public static final String NAME = "sigificantTerms";
|
||||
|
||||
@Override
|
||||
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
|
||||
return new SignifcantTermsQParser(qstr, localParams, params, req);
|
||||
}
|
||||
|
||||
private static class SignifcantTermsQParser extends QParser {
|
||||
|
||||
public SignifcantTermsQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
|
||||
super(qstr, localParams, params, req);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Query parse() throws SyntaxError {
|
||||
String field = getParam("field");
|
||||
int numTerms = Integer.parseInt(params.get("numTerms", "20"));
|
||||
float minDocs = Float.parseFloat(params.get("minDocFreq", "5"));
|
||||
float maxDocs = Float.parseFloat(params.get("maxDocFreq", ".3"));
|
||||
int minTermLength = Integer.parseInt(params.get("minTermLength", "4"));
|
||||
return new SignificantTermsQuery(field, numTerms, minDocs, maxDocs, minTermLength);
|
||||
}
|
||||
}
|
||||
|
||||
private static class SignificantTermsQuery extends AnalyticsQuery {
|
||||
|
||||
private String field;
|
||||
private int numTerms;
|
||||
private float maxDocs;
|
||||
private float minDocs;
|
||||
private int minTermLength;
|
||||
|
||||
public SignificantTermsQuery(String field, int numTerms, float minDocs, float maxDocs, int minTermLength) {
|
||||
this.field = field;
|
||||
this.numTerms = numTerms;
|
||||
this.minDocs = minDocs;
|
||||
this.maxDocs = maxDocs;
|
||||
this.minTermLength = minTermLength;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public DelegatingCollector getAnalyticsCollector(ResponseBuilder rb, IndexSearcher searcher) {
|
||||
return new SignifcantTermsCollector(rb, searcher, field, numTerms, minDocs, maxDocs, minTermLength);
|
||||
}
|
||||
}
|
||||
|
||||
private static class SignifcantTermsCollector extends DelegatingCollector {
|
||||
|
||||
private String field;
|
||||
private IndexSearcher searcher;
|
||||
private ResponseBuilder rb;
|
||||
private int numTerms;
|
||||
private SparseFixedBitSet docs;
|
||||
private int numDocs;
|
||||
private float minDocs;
|
||||
private float maxDocs;
|
||||
private int count;
|
||||
private int minTermLength;
|
||||
private int highestCollected;
|
||||
|
||||
public SignifcantTermsCollector(ResponseBuilder rb, IndexSearcher searcher, String field, int numTerms, float minDocs, float maxDocs, int minTermLength) {
|
||||
this.rb = rb;
|
||||
this.searcher = searcher;
|
||||
this.field = field;
|
||||
this.numTerms = numTerms;
|
||||
this.docs = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
|
||||
this.numDocs = searcher.getIndexReader().numDocs();
|
||||
this.minDocs = minDocs;
|
||||
this.maxDocs = maxDocs;
|
||||
this.minTermLength = minTermLength;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doSetNextReader(LeafReaderContext context) throws IOException {
|
||||
super.doSetNextReader(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void collect(int doc) throws IOException {
|
||||
super.collect(doc);
|
||||
highestCollected = context.docBase + doc;
|
||||
docs.set(highestCollected);
|
||||
++count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() throws IOException {
|
||||
List<String> outTerms = new ArrayList();
|
||||
List<Integer> outFreq = new ArrayList();
|
||||
List<Integer> outQueryFreq = new ArrayList();
|
||||
List<Double> scores = new ArrayList();
|
||||
|
||||
NamedList<Integer> allFreq = new NamedList();
|
||||
NamedList<Integer> allQueryFreq = new NamedList();
|
||||
|
||||
rb.rsp.add("numDocs", numDocs);
|
||||
rb.rsp.add("resultCount", count);
|
||||
rb.rsp.add("sterms", outTerms);
|
||||
rb.rsp.add("scores", scores);
|
||||
rb.rsp.add("docFreq", outFreq);
|
||||
rb.rsp.add("queryDocFreq", outQueryFreq);
|
||||
|
||||
//TODO: Use a priority queue
|
||||
TreeSet<TermWithScore> topTerms = new TreeSet<>();
|
||||
|
||||
Terms terms = MultiFields.getFields(searcher.getIndexReader()).terms(field);
|
||||
TermsEnum termsEnum = terms.iterator();
|
||||
BytesRef term;
|
||||
PostingsEnum postingsEnum = null;
|
||||
|
||||
while ((term = termsEnum.next()) != null) {
|
||||
int docFreq = termsEnum.docFreq();
|
||||
|
||||
if(minDocs < 1.0) {
|
||||
if((float)docFreq/numDocs < minDocs) {
|
||||
continue;
|
||||
}
|
||||
} else if(docFreq < minDocs) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if(maxDocs < 1.0) {
|
||||
if((float)docFreq/numDocs > maxDocs) {
|
||||
continue;
|
||||
}
|
||||
} else if(docFreq > maxDocs) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if(term.length < minTermLength) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int tf = 0;
|
||||
postingsEnum = termsEnum.postings(postingsEnum);
|
||||
|
||||
POSTINGS:
|
||||
while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
int docId = postingsEnum.docID();
|
||||
|
||||
if(docId > highestCollected) {
|
||||
break POSTINGS;
|
||||
}
|
||||
|
||||
if (docs.get(docId)) {
|
||||
++tf;
|
||||
}
|
||||
}
|
||||
|
||||
if(tf == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
float score = (float)Math.log(tf) * (float) (Math.log(((float)(numDocs + 1)) / (docFreq + 1)) + 1.0);
|
||||
|
||||
String t = term.utf8ToString();
|
||||
allFreq.add(t, docFreq);
|
||||
allQueryFreq.add(t, tf);
|
||||
|
||||
if (topTerms.size() < numTerms) {
|
||||
topTerms.add(new TermWithScore(term.utf8ToString(), score));
|
||||
} else {
|
||||
if (topTerms.first().score < score) {
|
||||
topTerms.pollFirst();
|
||||
topTerms.add(new TermWithScore(term.utf8ToString(), score));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (TermWithScore topTerm : topTerms) {
|
||||
outTerms.add(topTerm.term);
|
||||
scores.add(topTerm.score);
|
||||
outFreq.add(allFreq.get(topTerm.term));
|
||||
outQueryFreq.add(allQueryFreq.get(topTerm.term));
|
||||
}
|
||||
|
||||
if (this.delegate instanceof DelegatingCollector) {
|
||||
((DelegatingCollector) this.delegate).finish();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class TermWithScore implements Comparable<TermWithScore>{
|
||||
public final String term;
|
||||
public final double score;
|
||||
|
||||
public TermWithScore(String term, double score) {
|
||||
this.term = term;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return term.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == null) return false;
|
||||
if (obj.getClass() != getClass()) return false;
|
||||
TermWithScore other = (TermWithScore) obj;
|
||||
return other.term.equals(this.term);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(TermWithScore o) {
|
||||
int cmp = Double.compare(this.score, o.score);
|
||||
if (cmp == 0) {
|
||||
return this.term.compareTo(o.term);
|
||||
} else {
|
||||
return cmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -193,6 +193,15 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
|
||||
}
|
||||
}
|
||||
|
||||
public void testSignificantTermsQuery() throws Exception {
|
||||
SolrQueryRequest req = req("q", "*:*");
|
||||
try {
|
||||
assertQueryEquals("sigificantTerms", req, "{!sigificantTerms}");
|
||||
} finally {
|
||||
req.close();
|
||||
}
|
||||
}
|
||||
|
||||
public void testQuerySwitch() throws Exception {
|
||||
SolrQueryRequest req = req("myXXX", "XXX",
|
||||
"myField", "foo_s",
|
||||
|
@ -0,0 +1,444 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.stream;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
import org.apache.solr.client.solrj.impl.CloudSolrClient;
|
||||
import org.apache.solr.client.solrj.impl.HttpSolrClient;
|
||||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
import org.apache.solr.client.solrj.request.QueryRequest;
|
||||
import org.apache.solr.client.solrj.response.QueryResponse;
|
||||
import org.apache.solr.common.cloud.ClusterState;
|
||||
import org.apache.solr.common.cloud.Replica;
|
||||
import org.apache.solr.common.cloud.Slice;
|
||||
import org.apache.solr.common.cloud.ZkCoreNodeProps;
|
||||
import org.apache.solr.common.cloud.ZkStateReader;
|
||||
import org.apache.solr.common.params.ModifiableSolrParams;
|
||||
import org.apache.solr.common.util.ExecutorUtil;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.common.util.SolrjNamedThreadFactory;
|
||||
|
||||
public class SignificantTermsStream extends TupleStream implements Expressible{
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
protected String zkHost;
|
||||
protected String collection;
|
||||
protected Map<String,String> params;
|
||||
protected Iterator<Tuple> tupleIterator;
|
||||
protected String field;
|
||||
protected int numTerms;
|
||||
protected float minDocFreq;
|
||||
protected float maxDocFreq;
|
||||
protected int minTermLength;
|
||||
|
||||
protected transient SolrClientCache cache;
|
||||
protected transient boolean isCloseCache;
|
||||
protected transient CloudSolrClient cloudSolrClient;
|
||||
|
||||
protected transient StreamContext streamContext;
|
||||
protected ExecutorService executorService;
|
||||
|
||||
|
||||
public SignificantTermsStream(String zkHost,
|
||||
String collectionName,
|
||||
Map params,
|
||||
String field,
|
||||
float minDocFreq,
|
||||
float maxDocFreq,
|
||||
int minTermLength,
|
||||
int numTerms) throws IOException {
|
||||
|
||||
init(collectionName,
|
||||
zkHost,
|
||||
params,
|
||||
field,
|
||||
minDocFreq,
|
||||
maxDocFreq,
|
||||
minTermLength,
|
||||
numTerms);
|
||||
}
|
||||
|
||||
public SignificantTermsStream(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
// grab all parameters out
|
||||
String collectionName = factory.getValueOperand(expression, 0);
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
|
||||
|
||||
// Validate there are no unknown parameters - zkHost and alias are namedParameter so we don't need to count it twice
|
||||
if(expression.getParameters().size() != 1 + namedParams.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
|
||||
}
|
||||
|
||||
// Collection Name
|
||||
if(null == collectionName){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
|
||||
}
|
||||
|
||||
// Named parameters - passed directly to solr as solrparams
|
||||
if(0 == namedParams.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
|
||||
}
|
||||
|
||||
Map<String,String> params = new HashMap<String,String>();
|
||||
for(StreamExpressionNamedParameter namedParam : namedParams){
|
||||
if(!namedParam.getName().equals("zkHost")) {
|
||||
params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
|
||||
}
|
||||
}
|
||||
|
||||
String fieldParam = params.get("field");
|
||||
if(fieldParam != null) {
|
||||
params.remove("field");
|
||||
} else {
|
||||
throw new IOException("field param cannot be null for SignificantTermsStream");
|
||||
}
|
||||
|
||||
String numTermsParam = params.get("limit");
|
||||
int numTerms = 20;
|
||||
if(numTermsParam != null) {
|
||||
numTerms = Integer.parseInt(numTermsParam);
|
||||
params.remove("limit");
|
||||
}
|
||||
|
||||
String minTermLengthParam = params.get("minTermLength");
|
||||
int minTermLength = 4;
|
||||
if(minTermLengthParam != null) {
|
||||
minTermLength = Integer.parseInt(minTermLengthParam);
|
||||
params.remove("minTermLength");
|
||||
}
|
||||
|
||||
|
||||
String minDocFreqParam = params.get("minDocFreq");
|
||||
float minDocFreq = 5.0F;
|
||||
if(minDocFreqParam != null) {
|
||||
minDocFreq = Float.parseFloat(minDocFreqParam);
|
||||
params.remove("minDocFreq");
|
||||
}
|
||||
|
||||
String maxDocFreqParam = params.get("maxDocFreq");
|
||||
float maxDocFreq = .3F;
|
||||
if(maxDocFreqParam != null) {
|
||||
maxDocFreq = Float.parseFloat(maxDocFreqParam);
|
||||
params.remove("maxDocFreq");
|
||||
}
|
||||
|
||||
|
||||
// zkHost, optional - if not provided then will look into factory list to get
|
||||
String zkHost = null;
|
||||
if(null == zkHostExpression){
|
||||
zkHost = factory.getCollectionZkHost(collectionName);
|
||||
}
|
||||
else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
|
||||
zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
|
||||
}
|
||||
if(null == zkHost){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
|
||||
}
|
||||
|
||||
// We've got all the required items
|
||||
init(collectionName, zkHost, params, fieldParam, minDocFreq, maxDocFreq, minTermLength, numTerms);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
|
||||
// functionName(collectionName, param1, param2, ..., paramN, sort="comp", [aliases="field=alias,..."])
|
||||
|
||||
// function name
|
||||
StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
|
||||
|
||||
// collection
|
||||
expression.addParameter(collection);
|
||||
|
||||
// parameters
|
||||
for(Map.Entry<String,String> param : params.entrySet()){
|
||||
expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
|
||||
}
|
||||
|
||||
expression.addParameter(new StreamExpressionNamedParameter("field", field));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("minDocFreq", Float.toString(minDocFreq)));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("maxDocFreq", Float.toString(maxDocFreq)));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("numTerms", String.valueOf(numTerms)));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("minTermLength", String.valueOf(minTermLength)));
|
||||
|
||||
// zkHost
|
||||
expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
|
||||
|
||||
return expression;
|
||||
}
|
||||
|
||||
private void init(String collectionName,
|
||||
String zkHost,
|
||||
Map params,
|
||||
String field,
|
||||
float minDocFreq,
|
||||
float maxDocFreq,
|
||||
int minTermLength,
|
||||
int numTerms) throws IOException {
|
||||
this.zkHost = zkHost;
|
||||
this.collection = collectionName;
|
||||
this.params = params;
|
||||
this.field = field;
|
||||
this.minDocFreq = minDocFreq;
|
||||
this.maxDocFreq = maxDocFreq;
|
||||
this.numTerms = numTerms;
|
||||
this.minTermLength = minTermLength;
|
||||
}
|
||||
|
||||
public void setStreamContext(StreamContext context) {
|
||||
this.cache = context.getSolrClientCache();
|
||||
this.streamContext = context;
|
||||
}
|
||||
|
||||
public void open() throws IOException {
|
||||
if (cache == null) {
|
||||
isCloseCache = true;
|
||||
cache = new SolrClientCache();
|
||||
} else {
|
||||
isCloseCache = false;
|
||||
}
|
||||
|
||||
this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
|
||||
this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("FeaturesSelectionStream"));
|
||||
}
|
||||
|
||||
public List<TupleStream> children() {
|
||||
return null;
|
||||
}
|
||||
|
||||
private List<String> getShardUrls() throws IOException {
|
||||
try {
|
||||
ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
|
||||
|
||||
Collection<Slice> slices = CloudSolrStream.getSlices(this.collection, zkStateReader, false);
|
||||
|
||||
ClusterState clusterState = zkStateReader.getClusterState();
|
||||
Set<String> liveNodes = clusterState.getLiveNodes();
|
||||
|
||||
List<String> baseUrls = new ArrayList<>();
|
||||
for(Slice slice : slices) {
|
||||
Collection<Replica> replicas = slice.getReplicas();
|
||||
List<Replica> shuffler = new ArrayList<>();
|
||||
for(Replica replica : replicas) {
|
||||
if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
|
||||
shuffler.add(replica);
|
||||
}
|
||||
}
|
||||
|
||||
Collections.shuffle(shuffler, new Random());
|
||||
Replica rep = shuffler.get(0);
|
||||
ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
|
||||
String url = zkProps.getCoreUrl();
|
||||
baseUrls.add(url);
|
||||
}
|
||||
|
||||
return baseUrls;
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private List<Future<NamedList>> callShards(List<String> baseUrls) throws IOException {
|
||||
|
||||
List<Future<NamedList>> futures = new ArrayList<>();
|
||||
for (String baseUrl : baseUrls) {
|
||||
SignificantTermsCall lc = new SignificantTermsCall(baseUrl,
|
||||
this.params,
|
||||
this.field,
|
||||
this.minDocFreq,
|
||||
this.maxDocFreq,
|
||||
this.minTermLength,
|
||||
this.numTerms);
|
||||
|
||||
Future<NamedList> future = executorService.submit(lc);
|
||||
futures.add(future);
|
||||
}
|
||||
|
||||
return futures;
|
||||
}
|
||||
|
||||
public void close() throws IOException {
|
||||
if (isCloseCache) {
|
||||
cache.close();
|
||||
}
|
||||
|
||||
executorService.shutdown();
|
||||
}
|
||||
|
||||
/** Return the stream sort - ie, the order in which records are returned */
|
||||
public StreamComparator getStreamSort(){
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation toExplanation(StreamFactory factory) throws IOException {
|
||||
return new StreamExplanation(getStreamNodeId().toString())
|
||||
.withFunctionName(factory.getFunctionName(this.getClass()))
|
||||
.withImplementingClass(this.getClass().getName())
|
||||
.withExpressionType(Explanation.ExpressionType.STREAM_DECORATOR)
|
||||
.withExpression(toExpression(factory).toString());
|
||||
}
|
||||
|
||||
public Tuple read() throws IOException {
|
||||
try {
|
||||
if (tupleIterator == null) {
|
||||
Map<String, int[]> mergeFreqs = new HashMap<>();
|
||||
long numDocs = 0;
|
||||
long resultCount = 0;
|
||||
for (Future<NamedList> getTopTermsCall : callShards(getShardUrls())) {
|
||||
NamedList resp = getTopTermsCall.get();
|
||||
|
||||
List<String> terms = (List<String>)resp.get("sterms");
|
||||
List<Integer> docFreqs = (List<Integer>)resp.get("docFreq");
|
||||
List<Integer> queryDocFreqs = (List<Integer>)resp.get("queryDocFreq");
|
||||
numDocs += (Integer)resp.get("numDocs");
|
||||
resultCount += (Integer)resp.get("resultCount");
|
||||
|
||||
for (int i = 0; i < terms.size(); i++) {
|
||||
String term = terms.get(i);
|
||||
int docFreq = docFreqs.get(i);
|
||||
int queryDocFreq = queryDocFreqs.get(i);
|
||||
if(!mergeFreqs.containsKey(term)) {
|
||||
mergeFreqs.put(term, new int[2]);
|
||||
}
|
||||
|
||||
int[] freqs = mergeFreqs.get(term);
|
||||
freqs[0] += docFreq;
|
||||
freqs[1] += queryDocFreq;
|
||||
}
|
||||
}
|
||||
|
||||
List<Map> maps = new ArrayList();
|
||||
|
||||
for(String term : mergeFreqs.keySet() ) {
|
||||
int[] freqs = mergeFreqs.get(term);
|
||||
Map map = new HashMap();
|
||||
map.put("term", term);
|
||||
map.put("background", freqs[0]);
|
||||
map.put("foreground", freqs[1]);
|
||||
|
||||
float score = (float)Math.log(freqs[1]) * (float) (Math.log(((float)(numDocs + 1)) / (freqs[0] + 1)) + 1.0);
|
||||
|
||||
map.put("score", score);
|
||||
maps.add(map);
|
||||
}
|
||||
|
||||
Collections.sort(maps, new ScoreComp());
|
||||
List<Tuple> tuples = new ArrayList();
|
||||
for (Map map : maps) {
|
||||
if (tuples.size() == numTerms) break;
|
||||
tuples.add(new Tuple(map));
|
||||
}
|
||||
|
||||
Map map = new HashMap();
|
||||
map.put("EOF", true);
|
||||
tuples.add(new Tuple(map));
|
||||
tupleIterator = tuples.iterator();
|
||||
}
|
||||
|
||||
return tupleIterator.next();
|
||||
} catch(Exception e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private class ScoreComp implements Comparator<Map> {
|
||||
public int compare(Map a, Map b) {
|
||||
Float scorea = (Float)a.get("score");
|
||||
Float scoreb = (Float)b.get("score");
|
||||
return scoreb.compareTo(scorea);
|
||||
}
|
||||
}
|
||||
|
||||
protected class SignificantTermsCall implements Callable<NamedList> {
|
||||
|
||||
private String baseUrl;
|
||||
private String field;
|
||||
private float minDocFreq;
|
||||
private float maxDocFreq;
|
||||
private int numTerms;
|
||||
private int minTermLength;
|
||||
private Map<String, String> paramsMap;
|
||||
|
||||
public SignificantTermsCall(String baseUrl,
|
||||
Map<String, String> paramsMap,
|
||||
String field,
|
||||
float minDocFreq,
|
||||
float maxDocFreq,
|
||||
int minTermLength,
|
||||
int numTerms) {
|
||||
|
||||
this.baseUrl = baseUrl;
|
||||
this.field = field;
|
||||
this.minDocFreq = minDocFreq;
|
||||
this.maxDocFreq = maxDocFreq;
|
||||
this.paramsMap = paramsMap;
|
||||
this.numTerms = numTerms;
|
||||
this.minTermLength = minTermLength;
|
||||
}
|
||||
|
||||
public NamedList<Double> call() throws Exception {
|
||||
ModifiableSolrParams params = new ModifiableSolrParams();
|
||||
HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);
|
||||
|
||||
params.add("distrib", "false");
|
||||
params.add("fq","{!sigificantTerms}");
|
||||
|
||||
for(String key : paramsMap.keySet()) {
|
||||
params.add(key, paramsMap.get(key));
|
||||
}
|
||||
|
||||
params.add("minDocFreq", Float.toString(minDocFreq));
|
||||
params.add("maxDocFreq", Float.toString(maxDocFreq));
|
||||
params.add("minTermLength", Integer.toString(minTermLength));
|
||||
params.add("field", field);
|
||||
params.add("numTerms", String.valueOf(numTerms*3));
|
||||
|
||||
QueryRequest request= new QueryRequest(params);
|
||||
QueryResponse response = request.process(solrClient);
|
||||
NamedList res = response.getResponse();
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
@ -96,6 +96,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
||||
} else {
|
||||
collection = COLLECTIONORALIAS;
|
||||
}
|
||||
|
||||
CollectionAdminRequest.createCollection(collection, "conf", 2, 1).process(cluster.getSolrClient());
|
||||
AbstractDistribZkTestBase.waitForRecoveriesToFinish(collection, cluster.getSolrClient().getZkStateReader(),
|
||||
false, true, TIMEOUT);
|
||||
@ -4707,6 +4708,140 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
||||
CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testSignificantTermsStream() throws Exception {
|
||||
|
||||
Assume.assumeTrue(!useAlias);
|
||||
|
||||
UpdateRequest updateRequest = new UpdateRequest();
|
||||
for (int i = 0; i < 5000; i++) {
|
||||
updateRequest.add(id, "a"+i, "test_t", "a b c d m l");
|
||||
}
|
||||
|
||||
for (int i = 0; i < 5000; i++) {
|
||||
updateRequest.add(id, "b"+i, "test_t", "a b e f");
|
||||
}
|
||||
|
||||
for (int i = 0; i < 900; i++) {
|
||||
updateRequest.add(id, "c"+i, "test_t", "c");
|
||||
}
|
||||
|
||||
for (int i = 0; i < 600; i++) {
|
||||
updateRequest.add(id, "d"+i, "test_t", "d");
|
||||
}
|
||||
|
||||
for (int i = 0; i < 500; i++) {
|
||||
updateRequest.add(id, "e"+i, "test_t", "m");
|
||||
}
|
||||
|
||||
updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS);
|
||||
|
||||
TupleStream stream;
|
||||
List<Tuple> tuples;
|
||||
|
||||
StreamFactory factory = new StreamFactory()
|
||||
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
|
||||
.withFunctionName("significantTerms", SignificantTermsStream.class);
|
||||
|
||||
String significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minTermLength=1, maxDocFreq=\".5\")";
|
||||
stream = factory.constructStream(significantTerms);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 3);
|
||||
assertTrue(tuples.get(0).get("term").equals("l"));
|
||||
assertTrue(tuples.get(0).getLong("background") == 5000);
|
||||
assertTrue(tuples.get(0).getLong("foreground") == 5000);
|
||||
|
||||
|
||||
assertTrue(tuples.get(1).get("term").equals("m"));
|
||||
assertTrue(tuples.get(1).getLong("background") == 5500);
|
||||
assertTrue(tuples.get(1).getLong("foreground") == 5000);
|
||||
|
||||
assertTrue(tuples.get(2).get("term").equals("d"));
|
||||
assertTrue(tuples.get(2).getLong("background") == 5600);
|
||||
assertTrue(tuples.get(2).getLong("foreground") == 5000);
|
||||
|
||||
//Test maxDocFreq
|
||||
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, maxDocFreq=2650, minTermLength=1)";
|
||||
stream = factory.constructStream(significantTerms);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 1);
|
||||
assertTrue(tuples.get(0).get("term").equals("l"));
|
||||
assertTrue(tuples.get(0).getLong("background") == 5000);
|
||||
assertTrue(tuples.get(0).getLong("foreground") == 5000);
|
||||
|
||||
//Test maxDocFreq percentage
|
||||
|
||||
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, maxDocFreq=\".45\", minTermLength=1)";
|
||||
stream = factory.constructStream(significantTerms);
|
||||
tuples = getTuples(stream);
|
||||
assert(tuples.size() == 1);
|
||||
assertTrue(tuples.get(0).get("term").equals("l"));
|
||||
assertTrue(tuples.get(0).getLong("background") == 5000);
|
||||
assertTrue(tuples.get(0).getLong("foreground") == 5000);
|
||||
|
||||
|
||||
//Test min doc freq
|
||||
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minDocFreq=\"2700\", minTermLength=1, maxDocFreq=\".5\")";
|
||||
stream = factory.constructStream(significantTerms);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 3);
|
||||
|
||||
assertTrue(tuples.get(0).get("term").equals("m"));
|
||||
assertTrue(tuples.get(0).getLong("background") == 5500);
|
||||
assertTrue(tuples.get(0).getLong("foreground") == 5000);
|
||||
|
||||
assertTrue(tuples.get(1).get("term").equals("d"));
|
||||
assertTrue(tuples.get(1).getLong("background") == 5600);
|
||||
assertTrue(tuples.get(1).getLong("foreground") == 5000);
|
||||
|
||||
assertTrue(tuples.get(2).get("term").equals("c"));
|
||||
assertTrue(tuples.get(2).getLong("background") == 5900);
|
||||
assertTrue(tuples.get(2).getLong("foreground") == 5000);
|
||||
|
||||
|
||||
//Test min doc freq percent
|
||||
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minDocFreq=\".478\", minTermLength=1, maxDocFreq=\".5\")";
|
||||
stream = factory.constructStream(significantTerms);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 1);
|
||||
|
||||
assertTrue(tuples.get(0).get("term").equals("c"));
|
||||
assertTrue(tuples.get(0).getLong("background") == 5900);
|
||||
assertTrue(tuples.get(0).getLong("foreground") == 5000);
|
||||
|
||||
|
||||
//Test limit
|
||||
|
||||
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=2, minDocFreq=\"2700\", minTermLength=1, maxDocFreq=\".5\")";
|
||||
stream = factory.constructStream(significantTerms);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 2);
|
||||
|
||||
assertTrue(tuples.get(0).get("term").equals("m"));
|
||||
assertTrue(tuples.get(0).getLong("background") == 5500);
|
||||
assertTrue(tuples.get(0).getLong("foreground") == 5000);
|
||||
|
||||
assertTrue(tuples.get(1).get("term").equals("d"));
|
||||
assertTrue(tuples.get(1).getLong("background") == 5600);
|
||||
assertTrue(tuples.get(1).getLong("foreground") == 5000);
|
||||
|
||||
//Test term length
|
||||
|
||||
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=2, minDocFreq=\"2700\", minTermLength=2)";
|
||||
stream = factory.constructStream(significantTerms);
|
||||
tuples = getTuples(stream);
|
||||
assert(tuples.size() == 0);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testComplementStream() throws Exception {
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user