SOLR-9252: Feature selection and logistic regression on text

jbernste 2016-08-03 11:12:57 -04:00
parent 9fc4624853
commit 87938e00e9
14 changed files with 2076 additions and 8 deletions
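This commit registers two new streaming expressions in StreamHandler: features (FeaturesSelectionStream) and train (TextLogitStream). As a sketch of how they compose, assembled from the parameters the two streams parse below (collection, field, and outcome names are illustrative, not part of this commit):

train(collection1,
      features(collection1, q="*:*", field="body_txt", outcome="out_i", featureSet="fSet1", numTerms="200"),
      q="*:*", name="model1", field="body_txt", outcome="out_i", maxIterations="100")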

@@ -122,6 +122,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("intersect", IntersectStream.class)
.withFunctionName("complement", ComplementStream.class)
.withFunctionName("sort", SortStream.class)
.withFunctionName("train", TextLogitStream.class)
.withFunctionName("features", FeaturesSelectionStream.class)
.withFunctionName("daemon", DaemonStream.class)
.withFunctionName("shortestPath", ShortestPathStream.class)
.withFunctionName("gatherNodes", GatherNodesStream.class)

@@ -0,0 +1,239 @@
package org.apache.solr.search;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.TreeSet;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.request.SolrQueryRequest;
public class IGainTermsQParserPlugin extends QParserPlugin {
public static final String NAME = "igain";
@Override
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
return new IGainTermsQParser(qstr, localParams, params, req);
}
private static class IGainTermsQParser extends QParser {
public IGainTermsQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
super(qstr, localParams, params, req);
}
@Override
public Query parse() throws SyntaxError {
String field = getParam("field");
String outcome = getParam("outcome");
int numTerms = Integer.parseInt(getParam("numTerms"));
int positiveLabel = Integer.parseInt(getParam("positiveLabel"));
return new IGainTermsQuery(field, outcome, positiveLabel, numTerms);
}
}
private static class IGainTermsQuery extends AnalyticsQuery {
private String field;
private String outcome;
private int numTerms;
private int positiveLabel;
public IGainTermsQuery(String field, String outcome, int positiveLabel, int numTerms) {
this.field = field;
this.outcome = outcome;
this.numTerms = numTerms;
this.positiveLabel = positiveLabel;
}
@Override
public DelegatingCollector getAnalyticsCollector(ResponseBuilder rb, IndexSearcher searcher) {
return new IGainTermsCollector(rb, searcher, field, outcome, positiveLabel, numTerms);
}
}
private static class IGainTermsCollector extends DelegatingCollector {
private String field;
private String outcome;
private IndexSearcher searcher;
private ResponseBuilder rb;
private int positiveLabel;
private int numTerms;
private int count;
private NumericDocValues leafOutcomeValue;
private SparseFixedBitSet positiveSet;
private SparseFixedBitSet negativeSet;
private int numPositiveDocs;
public IGainTermsCollector(ResponseBuilder rb, IndexSearcher searcher, String field, String outcome, int positiveLabel, int numTerms) {
this.rb = rb;
this.searcher = searcher;
this.field = field;
this.outcome = outcome;
this.positiveSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
this.negativeSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
this.numTerms = numTerms;
this.positiveLabel = positiveLabel;
}
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
super.doSetNextReader(context);
LeafReader reader = context.reader();
leafOutcomeValue = reader.getNumericDocValues(outcome);
}
@Override
public void collect(int doc) throws IOException {
super.collect(doc);
++count;
if (leafOutcomeValue.get(doc) == positiveLabel) {
positiveSet.set(context.docBase + doc);
numPositiveDocs++;
} else {
negativeSet.set(context.docBase + doc);
}
}
@Override
public void finish() throws IOException {
NamedList<Double> analytics = new NamedList<Double>();
NamedList<Integer> topFreq = new NamedList();
NamedList<Integer> allFreq = new NamedList();
rb.rsp.add("featuredTerms", analytics);
rb.rsp.add("docFreq", topFreq);
rb.rsp.add("numDocs", count);
TreeSet<TermWithScore> topTerms = new TreeSet<>();
double numDocs = count;
double pc = numPositiveDocs / numDocs;
double entropyC = binaryEntropy(pc);
Terms terms = MultiFields.getFields(searcher.getIndexReader()).terms(field);
TermsEnum termsEnum = terms.iterator();
BytesRef term;
PostingsEnum postingsEnum = null;
while ((term = termsEnum.next()) != null) {
postingsEnum = termsEnum.postings(postingsEnum);
int xc = 0;
int nc = 0;
while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
if (positiveSet.get(postingsEnum.docID())) {
xc++;
} else if (negativeSet.get(postingsEnum.docID())) {
nc++;
}
}
int docFreq = xc+nc;
double entropyContainsTerm = binaryEntropy( (double) xc / docFreq );
double entropyNotContainsTerm = binaryEntropy( (double) (numPositiveDocs - xc) / (numDocs - docFreq + 1) );
double score = entropyC - ( (docFreq / numDocs) * entropyContainsTerm + (1.0 - docFreq / numDocs) * entropyNotContainsTerm);
allFreq.add(term.utf8ToString(), docFreq);
if (topTerms.size() < numTerms) {
topTerms.add(new TermWithScore(term.utf8ToString(), score));
} else {
if (topTerms.first().score < score) {
topTerms.pollFirst();
topTerms.add(new TermWithScore(term.utf8ToString(), score));
}
}
}
for (TermWithScore topTerm : topTerms) {
analytics.add(topTerm.term, topTerm.score);
topFreq.add(topTerm.term, allFreq.get(topTerm.term));
}
if (this.delegate instanceof DelegatingCollector) {
((DelegatingCollector) this.delegate).finish();
}
}
private double binaryEntropy(double prob) {
if (prob == 0 || prob == 1) return 0;
return (-1 * prob * Math.log(prob)) + (-1 * (1.0 - prob) * Math.log(1.0 - prob));
}
}
private static class TermWithScore implements Comparable<TermWithScore>{
public final String term;
public final double score;
public TermWithScore(String term, double score) {
this.term = term;
this.score = score;
}
@Override
public int hashCode() {
return term.hashCode();
}
@Override
public boolean equals(Object obj) {
if (obj == null) return false;
if (obj.getClass() != getClass()) return false;
TermWithScore other = (TermWithScore) obj;
return other.term.equals(this.term);
}
@Override
public int compareTo(TermWithScore o) {
int cmp = Double.compare(this.score, o.score);
if (cmp == 0) {
return this.term.compareTo(o.term);
} else {
return cmp;
}
}
}
}
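For reference, the score computed in finish() above is the information gain of a term t with respect to the outcome class. In the variables of the code (pc the positive-class prior, xc the positive docs containing t), up to the +1 smoothing on the complement branch it is:

\mathrm{IG}(t) = H(p_c) - \left[ \tfrac{docFreq}{numDocs}\, H\!\left(\tfrac{xc}{docFreq}\right) + \left(1 - \tfrac{docFreq}{numDocs}\right) H\!\left(\tfrac{numPositiveDocs - xc}{numDocs - docFreq}\right) \right], \qquad H(p) = -p\ln p - (1-p)\ln(1-p)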

@@ -16,6 +16,11 @@
*/
package org.apache.solr.search;
import java.net.URL;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrInfoMBean;
@@ -26,11 +31,6 @@ import org.apache.solr.search.join.GraphQParserPlugin;
import org.apache.solr.search.mlt.MLTQParserPlugin;
import org.apache.solr.util.plugin.NamedListInitializedPlugin;
import java.net.URL;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrInfoMBean {
/** internal use - name of the default parser */
public static final String DEFAULT_QTYPE = LuceneQParserPlugin.NAME;
@@ -77,6 +77,8 @@ public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrI
map.put(GraphQParserPlugin.NAME, GraphQParserPlugin.class);
map.put(XmlQParserPlugin.NAME, XmlQParserPlugin.class);
map.put(GraphTermsQParserPlugin.NAME, GraphTermsQParserPlugin.class);
map.put(IGainTermsQParserPlugin.NAME, IGainTermsQParserPlugin.class);
map.put(TextLogisticRegressionQParserPlugin.NAME, TextLogisticRegressionQParserPlugin.class);
standardPlugins = Collections.unmodifiableMap(map);
}
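Registration here makes both parsers addressable as local-params filter queries; the stream classes below invoke them per shard in exactly this form. A hypothetical single-core request for the igain parser (field and outcome names illustrative), mirroring the parameters FeaturesSelectionCall sends:

/select?q=*:*&distrib=false&fq={!igain}&field=body_txt&outcome=out_i&positiveLabel=1&numTerms=100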

@@ -0,0 +1,283 @@
package org.apache.solr.search;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.request.SolrQueryRequest;
/**
* Returns an AnalyticsQuery implementation that performs
* one gradient descent iteration over a result set to train a
* logistic regression model.
*
* The TextLogitStream provides the parallel iterative framework for this class.
**/
public class TextLogisticRegressionQParserPlugin extends QParserPlugin {
public static final String NAME = "tlogit";
@Override
public void init(NamedList args) {
}
@Override
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
return new TextLogisticRegressionQParser(qstr, localParams, params, req);
}
private static class TextLogisticRegressionQParser extends QParser{
TextLogisticRegressionQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
super(qstr, localParams, params, req);
}
public Query parse() {
String fs = params.get("feature");
String[] terms = params.get("terms").split(",");
String ws = params.get("weights");
String dfsStr = params.get("idfs");
int iteration = params.getInt("iteration");
String outcome = params.get("outcome");
int positiveLabel = params.getInt("positiveLabel", 1);
double threshold = params.getDouble("threshold", 0.5);
double alpha = params.getDouble("alpha", 0.01);
double[] idfs = new double[terms.length];
String[] idfsArr = dfsStr.split(",");
for (int i = 0; i < idfsArr.length; i++) {
idfs[i] = Double.parseDouble(idfsArr[i]);
}
double[] weights = new double[terms.length+1];
if(ws != null) {
String[] wa = ws.split(",");
for (int i = 0; i < wa.length; i++) {
weights[i] = Double.parseDouble(wa[i]);
}
} else {
for(int i=0; i<weights.length; i++) {
weights[i]= 1.0d;
}
}
TrainingParams input = new TrainingParams(fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold);
return new TextLogisticRegressionQuery(input);
}
}
private static class TextLogisticRegressionQuery extends AnalyticsQuery {
private TrainingParams trainingParams;
public TextLogisticRegressionQuery(TrainingParams trainingParams) {
this.trainingParams = trainingParams;
}
public DelegatingCollector getAnalyticsCollector(ResponseBuilder rbsp, IndexSearcher indexSearcher) {
return new TextLogisticRegressionCollector(rbsp, indexSearcher, trainingParams);
}
}
private static class TextLogisticRegressionCollector extends DelegatingCollector {
private TrainingParams trainingParams;
private LeafReader leafReader;
private double[] workingDeltas;
private ClassificationEvaluation classificationEvaluation;
private double[] weights;
private ResponseBuilder rbsp;
private NumericDocValues leafOutcomeValue;
private double totalError;
private SparseFixedBitSet positiveDocsSet;
private SparseFixedBitSet docsSet;
private IndexSearcher searcher;
TextLogisticRegressionCollector(ResponseBuilder rbsp, IndexSearcher searcher,
TrainingParams trainingParams) {
this.trainingParams = trainingParams;
this.workingDeltas = new double[trainingParams.weights.length];
this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length);
this.rbsp = rbsp;
this.classificationEvaluation = new ClassificationEvaluation();
this.searcher = searcher;
positiveDocsSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
docsSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
}
public void doSetNextReader(LeafReaderContext context) throws IOException {
super.doSetNextReader(context);
leafReader = context.reader();
leafOutcomeValue = leafReader.getNumericDocValues(trainingParams.outcome);
}
public void collect(int doc) throws IOException{
int outcome = (int) leafOutcomeValue.get(doc);
outcome = trainingParams.positiveLabel == outcome? 1 : 0;
if (outcome == 1) {
positiveDocsSet.set(context.docBase + doc);
}
docsSet.set(context.docBase+doc);
}
public void finish() throws IOException {
Map<Integer, double[]> docVectors = new HashMap<>();
Terms terms = MultiFields.getFields(searcher.getIndexReader()).terms(trainingParams.feature);
TermsEnum termsEnum = terms.iterator();
PostingsEnum postingsEnum = null;
int termIndex = 0;
for (String termStr : trainingParams.terms) {
BytesRef term = new BytesRef(termStr);
if (termsEnum.seekExact(term)) {
postingsEnum = termsEnum.postings(postingsEnum);
while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
int docId = postingsEnum.docID();
if (docsSet.get(docId)) {
double[] vector = docVectors.get(docId);
if (vector == null) {
vector = new double[trainingParams.terms.length+1];
vector[0] = 1.0;
docVectors.put(docId, vector);
}
vector[termIndex + 1] = trainingParams.idfs[termIndex] * (1.0 + Math.log(postingsEnum.freq()));
}
}
}
termIndex++;
}
for (Map.Entry<Integer, double[]> entry : docVectors.entrySet()) {
double[] vector = entry.getValue();
int outcome = 0;
if (positiveDocsSet.get(entry.getKey())) {
outcome = 1;
}
double sig = sigmoid(sum(multiply(vector, weights)));
double error = sig - outcome;
double lastSig = sigmoid(sum(multiply(vector, trainingParams.weights)));
totalError += Math.abs(lastSig - outcome);
classificationEvaluation.count(outcome, lastSig >= trainingParams.threshold ? 1 : 0);
workingDeltas = multiply(error * trainingParams.alpha, vector);
for(int i = 0; i< workingDeltas.length; i++) {
weights[i] -= workingDeltas[i];
}
}
NamedList analytics = new NamedList();
rbsp.rsp.add("logit", analytics);
List<Double> outWeights = new ArrayList<>();
for(Double d : weights) {
outWeights.add(d);
}
analytics.add("weights", outWeights);
analytics.add("error", totalError);
analytics.add("evaluation", classificationEvaluation.toMap());
analytics.add("feature", trainingParams.feature);
analytics.add("positiveLabel", trainingParams.positiveLabel);
if(this.delegate instanceof DelegatingCollector) {
((DelegatingCollector)this.delegate).finish();
}
}
private double sigmoid(double in) {
double d = 1.0 / (1+Math.exp(-in));
return d;
}
private double[] multiply(double[] vals, double[] weights) {
for(int i = 0; i < vals.length; ++i) {
workingDeltas[i] = vals[i] * weights[i];
}
return workingDeltas;
}
private double[] multiply(double d, double[] vals) {
for(int i = 0; i<vals.length; ++i) {
workingDeltas[i] = vals[i] * d;
}
return workingDeltas;
}
private double sum(double[] vals) {
double d = 0.0d;
for(double val : vals) {
d += val;
}
return d;
}
}
private static class TrainingParams {
public final String feature;
public final String[] terms;
public final double[] idfs;
public final String outcome;
public final double[] weights;
public final int iteration;
public final int positiveLabel;
public final double threshold;
public final double alpha;
public TrainingParams(String feature, String[] terms, double[] idfs, String outcome, double[] weights, int iteration, double alpha, int positiveLabel, double threshold) {
this.feature = feature;
this.terms = terms;
this.idfs = idfs;
this.outcome = outcome;
this.weights = weights;
this.alpha = alpha;
this.iteration = iteration;
this.positiveLabel = positiveLabel;
this.threshold = threshold;
}
}
}
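The collector's finish() performs one batch gradient-descent step for logistic regression: for each tf-idf document vector x with label y in {0,1},

w \leftarrow w - \alpha\,\big(\sigma(w^{\top}x) - y\big)\,x, \qquad \sigma(z) = \frac{1}{1+e^{-z}}

Note that the reported error and evaluation are computed against the incoming weights (trainingParams.weights), while the returned weights carry the update.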

@@ -175,6 +175,24 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
}
}
public void testTlogitQuery() throws Exception {
SolrQueryRequest req = req("q", "*:*", "feature", "f", "terms","a,b,c", "weights", "100,200,300", "idfs","1,5,7","iteration","1", "outcome","a","positiveLabel","1");
try {
assertQueryEquals("tlogit", req, "{!tlogit}");
} finally {
req.close();
}
}
public void testIGainQuery() throws Exception {
SolrQueryRequest req = req("q", "*:*", "outcome", "b", "positiveLabel", "1", "field", "x", "numTerms","200");
try {
assertQueryEquals("igain", req, "{!igain}");
} finally {
req.close();
}
}
public void testQuerySwitch() throws Exception {
SolrQueryRequest req = req("myXXX", "XXX",
"myField", "foo_s",

@@ -0,0 +1,85 @@
package org.apache.solr.client.solrj.io;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.util.HashMap;
import java.util.Map;
public class ClassificationEvaluation {
private long truePositive;
private long falsePositive;
private long trueNegative;
private long falseNegative;
public void count(int actual, int predicted) {
if (predicted == 1) {
if (actual == 1) truePositive++;
else falsePositive++;
} else {
if (actual == 0) trueNegative++;
else falseNegative++;
}
}
public void putToMap(Map map) {
map.put("truePositive_i",truePositive);
map.put("trueNegative_i",trueNegative);
map.put("falsePositive_i",falsePositive);
map.put("falseNegative_i",falseNegative);
}
public Map toMap() {
HashMap map = new HashMap();
putToMap(map);
return map;
}
public static ClassificationEvaluation create(Map map) {
ClassificationEvaluation evaluation = new ClassificationEvaluation();
evaluation.addEvaluation(map);
return evaluation;
}
public void addEvaluation(Map map) {
this.truePositive += (long) map.get("truePositive_i");
this.trueNegative += (long) map.get("trueNegative_i");
this.falsePositive += (long) map.get("falsePositive_i");
this.falseNegative += (long) map.get("falseNegative_i");
}
public double getPrecision() {
if (truePositive + falsePositive == 0) return 0;
return (double) truePositive / (truePositive + falsePositive);
}
public double getRecall() {
if (truePositive + falseNegative == 0) return 0;
return (double) truePositive / (truePositive + falseNegative);
}
public double getF1() {
double precision = getPrecision();
double recall = getRecall();
if (precision + recall == 0) return 0;
return 2 * (precision * recall) / (precision + recall);
}
public double getAccuracy() {
return (double) (truePositive + trueNegative) / (truePositive + trueNegative + falseNegative + falsePositive);
}
}
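A minimal usage sketch of the confusion-matrix tally (counts and printed values are illustrative):

ClassificationEvaluation eval = new ClassificationEvaluation();
eval.count(1, 1); // actual positive, predicted positive -> truePositive
eval.count(1, 0); // actual positive, predicted negative -> falseNegative
eval.count(0, 0); // actual negative, predicted negative -> trueNegative
eval.count(0, 1); // actual negative, predicted positive -> falsePositive
// precision = 1/2, recall = 1/2, F1 = 0.5, accuracy = 2/4 = 0.5
System.out.println(eval.getPrecision() + " " + eval.getRecall() + " " + eval.getF1() + " " + eval.getAccuracy());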

@@ -0,0 +1,436 @@
package org.apache.solr.client.solrj.io.stream;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.impl.HttpSolrClient;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.cloud.Replica;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.cloud.ZkCoreNodeProps;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SolrjNamedThreadFactory;
public class FeaturesSelectionStream extends TupleStream implements Expressible{
private static final long serialVersionUID = 1;
protected String zkHost;
protected String collection;
protected Map<String,String> params;
protected Iterator<Tuple> tupleIterator;
protected String field;
protected String outcome;
protected String featureSet;
protected int positiveLabel;
protected int numTerms;
protected transient SolrClientCache cache;
protected transient boolean isCloseCache;
protected transient CloudSolrClient cloudSolrClient;
protected transient StreamContext streamContext;
protected ExecutorService executorService;
public FeaturesSelectionStream(String zkHost,
String collectionName,
Map params,
String field,
String outcome,
String featureSet,
int positiveLabel,
int numTerms) throws IOException {
init(collectionName, zkHost, params, field, outcome, featureSet, positiveLabel, numTerms);
}
/**
* features(collection, zkHost="", field="f", outcome="y", featureSet="fSet", numTerms="100")
**/
public FeaturesSelectionStream(StreamExpression expression, StreamFactory factory) throws IOException{
// grab all parameters out
String collectionName = factory.getValueOperand(expression, 0);
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
// Validate there are no unknown parameters - zkHost and alias are named parameters, so we don't need to count them twice
if(expression.getParameters().size() != 1 + namedParams.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
}
// Collection Name
if(null == collectionName){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
}
// Named parameters - passed directly to solr as solrparams
if(0 == namedParams.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
}
Map<String,String> params = new HashMap<String,String>();
for(StreamExpressionNamedParameter namedParam : namedParams){
if(!namedParam.getName().equals("zkHost")) {
params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
}
}
String fieldParam = params.get("field");
if(fieldParam != null) {
params.remove("field");
} else {
throw new IOException("field param cannot be null for FeaturesSelectionStream");
}
String outcomeParam = params.get("outcome");
if(outcomeParam != null) {
params.remove("outcome");
} else {
throw new IOException("outcome param cannot be null for FeaturesSelectionStream");
}
String featureSetParam = params.get("featureSet");
if(featureSetParam != null) {
params.remove("featureSet");
} else {
throw new IOException("featureSet param cannot be null for FeaturesSelectionStream");
}
String positiveLabelParam = params.get("positiveLabel");
int positiveLabel = 1;
if(positiveLabelParam != null) {
params.remove("positiveLabel");
positiveLabel = Integer.parseInt(positiveLabelParam);
}
String numTermsParam = params.get("numTerms");
int numTerms = 1;
if(numTermsParam != null) {
numTerms = Integer.parseInt(numTermsParam);
params.remove("numTerms");
} else {
throw new IOException("numTerms param cannot be null for FeaturesSelectionStream");
}
// zkHost, optional - if not provided, look it up from the factory's collection mapping
String zkHost = null;
if(null == zkHostExpression){
zkHost = factory.getCollectionZkHost(collectionName);
}
else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
}
if(null == zkHost){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
}
// We've got all the required items
init(collectionName, zkHost, params, fieldParam, outcomeParam, featureSetParam, positiveLabel, numTerms);
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
// functionName(collectionName, param1, param2, ..., paramN, sort="comp", [aliases="field=alias,..."])
// function name
StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
// collection
expression.addParameter(collection);
// parameters
for(Map.Entry<String,String> param : params.entrySet()){
expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
}
expression.addParameter(new StreamExpressionNamedParameter("field", field));
expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
expression.addParameter(new StreamExpressionNamedParameter("featureSet", featureSet));
expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", String.valueOf(positiveLabel)));
expression.addParameter(new StreamExpressionNamedParameter("numTerms", String.valueOf(numTerms)));
// zkHost
expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
return expression;
}
private void init(String collectionName,
String zkHost,
Map params,
String field,
String outcome,
String featureSet,
int positiveLabel, int numTopTerms) throws IOException {
this.zkHost = zkHost;
this.collection = collectionName;
this.params = params;
this.field = field;
this.outcome = outcome;
this.featureSet = featureSet;
this.positiveLabel = positiveLabel;
this.numTerms = numTopTerms;
}
public void setStreamContext(StreamContext context) {
this.cache = context.getSolrClientCache();
this.streamContext = context;
}
/**
* Opens the FeaturesSelectionStream
**/
public void open() throws IOException {
if (cache == null) {
isCloseCache = true;
cache = new SolrClientCache();
} else {
isCloseCache = false;
}
this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("FeaturesSelectionStream"));
}
public List<TupleStream> children() {
return null;
}
private List<String> getShardUrls() throws IOException {
try {
ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
ClusterState clusterState = zkStateReader.getClusterState();
Collection<Slice> slices = clusterState.getActiveSlices(this.collection);
Set<String> liveNodes = clusterState.getLiveNodes();
List<String> baseUrls = new ArrayList<>();
for(Slice slice : slices) {
Collection<Replica> replicas = slice.getReplicas();
List<Replica> shuffler = new ArrayList<>();
for(Replica replica : replicas) {
if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
shuffler.add(replica);
}
}
Collections.shuffle(shuffler, new Random());
Replica rep = shuffler.get(0);
ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
String url = zkProps.getCoreUrl();
baseUrls.add(url);
}
return baseUrls;
} catch (Exception e) {
throw new IOException(e);
}
}
private List<Future<NamedList>> callShards(List<String> baseUrls) throws IOException {
List<Future<NamedList>> futures = new ArrayList<>();
for (String baseUrl : baseUrls) {
FeaturesSelectionCall lc = new FeaturesSelectionCall(baseUrl,
this.params,
this.field,
this.outcome);
Future<NamedList> future = executorService.submit(lc);
futures.add(future);
}
return futures;
}
public void close() throws IOException {
if (isCloseCache) {
cache.close();
}
executorService.shutdown();
}
/** Return the stream sort - i.e., the order in which records are returned */
public StreamComparator getStreamSort(){
return null;
}
@Override
public Explanation toExplanation(StreamFactory factory) throws IOException {
return new StreamExplanation(getStreamNodeId().toString())
.withFunctionName(factory.getFunctionName(this.getClass()))
.withImplementingClass(this.getClass().getName())
.withExpressionType(Explanation.ExpressionType.STREAM_DECORATOR)
.withExpression(toExpression(factory).toString());
}
public Tuple read() throws IOException {
try {
if (tupleIterator == null) {
Map<String, Double> termScores = new HashMap<>();
Map<String, Long> docFreqs = new HashMap<>();
long numDocs = 0;
for (Future<NamedList> getTopTermsCall : callShards(getShardUrls())) {
NamedList resp = getTopTermsCall.get();
NamedList<Double> shardTopTerms = (NamedList<Double>)resp.get("featuredTerms");
NamedList<Integer> shardDocFreqs = (NamedList<Integer>)resp.get("docFreq");
numDocs += (Integer)resp.get("numDocs");
for (int i = 0; i < shardTopTerms.size(); i++) {
String term = shardTopTerms.getName(i);
double score = shardTopTerms.getVal(i);
int docFreq = shardDocFreqs.get(term);
double prevScore = termScores.containsKey(term) ? termScores.get(term) : 0;
long prevDocFreq = docFreqs.containsKey(term) ? docFreqs.get(term) : 0;
termScores.put(term, prevScore + score);
docFreqs.put(term, prevDocFreq + docFreq);
}
}
List<Tuple> tuples = new ArrayList<>(numTerms);
termScores = sortByValue(termScores);
int index = 0;
for (Map.Entry<String, Double> termScore : termScores.entrySet()) {
if (tuples.size() == numTerms) break;
index++;
Map map = new HashMap();
map.put("id", featureSet + "_" + index);
map.put("index_i", index);
map.put("term_s", termScore.getKey());
map.put("score_f", termScore.getValue());
map.put("featureSet_s", featureSet);
long docFreq = docFreqs.get(termScore.getKey());
double d = Math.log(((double)numDocs / (double)(docFreq + 1)));
map.put("idf_d", d);
tuples.add(new Tuple(map));
}
Map map = new HashMap();
map.put("EOF", true);
tuples.add(new Tuple(map));
tupleIterator = tuples.iterator();
}
return tupleIterator.next();
} catch(Exception e) {
throw new IOException(e);
}
}
private <K, V extends Comparable<? super V>> Map<K, V> sortByValue( Map<K, V> map )
{
Map<K, V> result = new LinkedHashMap<>();
Stream<Map.Entry<K, V>> st = map.entrySet().stream();
st.sorted( Map.Entry.comparingByValue(
(c1, c2) -> c2.compareTo(c1)
) ).forEachOrdered( e -> result.put(e.getKey(), e.getValue()) );
return result;
}
protected class FeaturesSelectionCall implements Callable<NamedList> {
private String baseUrl;
private String outcome;
private String field;
private Map<String, String> paramsMap;
public FeaturesSelectionCall(String baseUrl,
Map<String, String> paramsMap,
String field,
String outcome) {
this.baseUrl = baseUrl;
this.outcome = outcome;
this.field = field;
this.paramsMap = paramsMap;
}
public NamedList<Double> call() throws Exception {
ModifiableSolrParams params = new ModifiableSolrParams();
HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);
params.add("distrib", "false");
params.add("fq","{!igain}");
for(String key : paramsMap.keySet()) {
params.add(key, paramsMap.get(key));
}
params.add("outcome", outcome);
params.add("positiveLabel", Integer.toString(positiveLabel));
params.add("field", field);
params.add("numTerms", String.valueOf(numTerms));
QueryRequest request= new QueryRequest(params);
QueryResponse response = request.process(solrClient);
NamedList res = response.getResponse();
return res;
}
}
}
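read() merges the per-shard igain scores by summation and attaches an inverse document frequency to each selected term; from the computation above, with N the summed numDocs and df_t the merged docFreq:

\mathrm{idf}(t) = \ln\!\left(\frac{N}{df_t + 1}\right)

TextLogisticRegressionCollector then weights each term occurrence as idf(t) * (1 + ln(tf)), so feature selection and training share the same scaling.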

@@ -0,0 +1,657 @@
package org.apache.solr.client.solrj.io.stream;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.solr.client.solrj.SolrRequest;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.impl.HttpSolrClient;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.cloud.Replica;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.cloud.ZkCoreNodeProps;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SolrjNamedThreadFactory;
public class TextLogitStream extends TupleStream implements Expressible {
private static final long serialVersionUID = 1;
protected String zkHost;
protected String collection;
protected Map<String,String> params;
protected String field;
protected String name;
protected String outcome;
protected int positiveLabel;
protected double threshold;
protected List<Double> weights;
protected int maxIterations;
protected int iteration;
protected double error;
protected List<Double> idfs;
protected ClassificationEvaluation evaluation;
protected transient SolrClientCache cache;
protected transient boolean isCloseCache;
protected transient CloudSolrClient cloudSolrClient;
protected transient StreamContext streamContext;
protected ExecutorService executorService;
protected TupleStream termsStream;
private List<String> terms;
private double learningRate = 0.01;
private double lastError = 0;
public TextLogitStream(String zkHost,
String collectionName,
Map params,
String name,
String field,
TupleStream termsStream,
List<Double> weights,
String outcome,
int positiveLabel,
double threshold,
int maxIterations) throws IOException {
init(collectionName, zkHost, params, name, field, termsStream, weights, outcome, positiveLabel, threshold, maxIterations, iteration);
}
/**
* train(collection, features(...), zkHost="", name="model", field="f", outcome="y", maxIterations="20")
**/
public TextLogitStream(StreamExpression expression, StreamFactory factory) throws IOException{
// grab all parameters out
String collectionName = factory.getValueOperand(expression, 0);
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
// Validate there are no unknown parameters - zkHost and alias are named parameters, so we don't need to count them twice
if(expression.getParameters().size() != 1 + namedParams.size() + streamExpressions.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
}
// Collection Name
if(null == collectionName){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
}
// Named parameters - passed directly to solr as solrparams
if(0 == namedParams.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
}
Map<String,String> params = new HashMap<String,String>();
for(StreamExpressionNamedParameter namedParam : namedParams){
if(!namedParam.getName().equals("zkHost")) {
params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
}
}
String name = params.get("name");
if (name != null) {
params.remove("name");
} else {
throw new IOException("name param cannot be null for TextLogitStream");
}
String feature = params.get("field");
if (feature != null) {
params.remove("field");
} else {
throw new IOException("field param cannot be null for TextLogitStream");
}
TupleStream stream = null;
if (streamExpressions.size() > 0) {
stream = factory.constructStream(streamExpressions.get(0));
} else {
throw new IOException("features must be present for TextLogitStream");
}
String maxIterationsParam = params.get("maxIterations");
int maxIterations = 0;
if(maxIterationsParam != null) {
maxIterations = Integer.parseInt(maxIterationsParam);
params.remove("maxIterations");
} else {
throw new IOException("maxIterations param cannot be null for TextLogitStream");
}
String outcomeParam = params.get("outcome");
if(outcomeParam != null) {
params.remove("outcome");
} else {
throw new IOException("outcome param cannot be null for TextLogitStream");
}
String positiveLabelParam = params.get("positiveLabel");
int positiveLabel = 1;
if(positiveLabelParam != null) {
positiveLabel = Integer.parseInt(positiveLabelParam);
params.remove("positiveLabel");
}
String thresholdParam = params.get("threshold");
double threshold = 0.5;
if(thresholdParam != null) {
threshold = Double.parseDouble(thresholdParam);
params.remove("threshold");
}
int iteration = 0;
String iterationParam = params.get("iteration");
if(iterationParam != null) {
iteration = Integer.parseInt(iterationParam);
params.remove("iteration");
}
List<Double> weights = null;
String weightsParam = params.get("weights");
if(weightsParam != null) {
weights = new ArrayList<>();
String[] weightsArray = weightsParam.split(",");
for(String weightString : weightsArray) {
weights.add(Double.parseDouble(weightString));
}
params.remove("weights");
}
// zkHost, optional - if not provided, look it up from the factory's collection mapping
String zkHost = null;
if(null == zkHostExpression){
zkHost = factory.getCollectionZkHost(collectionName);
}
else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
}
if(null == zkHost){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
}
// We've got all the required items
init(collectionName, zkHost, params, name, feature, stream, weights, outcomeParam, positiveLabel, threshold, maxIterations, iteration);
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
return toExpression(factory, true);
}
private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
// function name
StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
// collection
expression.addParameter(collection);
if (includeStreams && !(termsStream instanceof TermsStream)) {
if (termsStream instanceof Expressible) {
expression.addParameter(((Expressible)termsStream).toExpression(factory));
} else {
throw new IOException("This TextLogitStream contains a non-expressible TupleStream - it cannot be converted to an expression");
}
}
// parameters
for(Entry<String,String> param : params.entrySet()){
expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
}
expression.addParameter(new StreamExpressionNamedParameter("field", field));
expression.addParameter(new StreamExpressionNamedParameter("name", name));
if (termsStream instanceof TermsStream) {
loadTerms();
expression.addParameter(new StreamExpressionNamedParameter("terms", toString(terms)));
}
expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
if(weights != null) {
expression.addParameter(new StreamExpressionNamedParameter("weights", toString(weights)));
}
expression.addParameter(new StreamExpressionNamedParameter("maxIterations", Integer.toString(maxIterations)));
if(iteration > 0) {
expression.addParameter(new StreamExpressionNamedParameter("iteration", Integer.toString(iteration)));
}
expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", Integer.toString(positiveLabel)));
expression.addParameter(new StreamExpressionNamedParameter("threshold", Double.toString(threshold)));
// zkHost
expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
return expression;
}
private void init(String collectionName,
String zkHost,
Map params,
String name,
String feature,
TupleStream termsStream,
List<Double> weights,
String outcome,
int positiveLabel,
double threshold,
int maxIterations,
int iteration) throws IOException {
this.zkHost = zkHost;
this.collection = collectionName;
this.params = params;
this.name = name;
this.field = feature;
this.termsStream = termsStream;
this.outcome = outcome;
this.positiveLabel = positiveLabel;
this.threshold = threshold;
this.weights = weights;
this.maxIterations = maxIterations;
this.iteration = iteration;
}
public void setStreamContext(StreamContext context) {
this.cache = context.getSolrClientCache();
this.streamContext = context;
this.termsStream.setStreamContext(context);
}
/**
* Opens the TextLogitStream
**/
public void open() throws IOException {
if (cache == null) {
isCloseCache = true;
cache = new SolrClientCache();
} else {
isCloseCache = false;
}
this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("TextLogitSolrStream"));
}
public List<TupleStream> children() {
List<TupleStream> l = new ArrayList();
l.add(termsStream);
return l;
}
protected List<String> getShardUrls() throws IOException {
try {
ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
ClusterState clusterState = zkStateReader.getClusterState();
Set<String> liveNodes = clusterState.getLiveNodes();
Collection<Slice> slices = clusterState.getActiveSlices(this.collection);
List baseUrls = new ArrayList();
for(Slice slice : slices) {
Collection<Replica> replicas = slice.getReplicas();
List<Replica> shuffler = new ArrayList();
for(Replica replica : replicas) {
if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
shuffler.add(replica);
}
}
Collections.shuffle(shuffler, new Random());
Replica rep = shuffler.get(0);
ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
String url = zkProps.getCoreUrl();
baseUrls.add(url);
}
return baseUrls;
} catch (Exception e) {
throw new IOException(e);
}
}
private List<Future<Tuple>> callShards(List<String> baseUrls) throws IOException {
List<Future<Tuple>> futures = new ArrayList();
for (String baseUrl : baseUrls) {
LogitCall lc = new LogitCall(baseUrl,
this.params,
this.field,
this.terms,
this.weights,
this.outcome,
this.positiveLabel,
this.learningRate,
this.iteration);
Future<Tuple> future = executorService.submit(lc);
futures.add(future);
}
return futures;
}
public void close() throws IOException {
if (isCloseCache) {
cache.close();
}
executorService.shutdown();
termsStream.close();
}
/** Return the stream sort - i.e., the order in which records are returned */
public StreamComparator getStreamSort(){
return null;
}
@Override
public Explanation toExplanation(StreamFactory factory) throws IOException {
StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
explanation.setFunctionName(factory.getFunctionName(this.getClass()));
explanation.setImplementingClass(this.getClass().getName());
explanation.setExpressionType(Explanation.ExpressionType.MACHINE_LEARNING_MODEL);
explanation.setExpression(toExpression(factory).toString());
explanation.addChild(termsStream.toExplanation(factory));
return explanation;
}
public void loadTerms() throws IOException {
if (this.terms == null) {
termsStream.open();
this.terms = new ArrayList<>();
this.idfs = new ArrayList();
while (true) {
Tuple termTuple = termsStream.read();
if (termTuple.EOF) {
break;
} else {
terms.add(termTuple.getString("term_s"));
idfs.add(termTuple.getDouble("idf_d"));
}
}
termsStream.close();
}
}
public Tuple read() throws IOException {
try {
if(++iteration > maxIterations) {
Map map = new HashMap();
map.put("EOF", true);
return new Tuple(map);
} else {
if (this.idfs == null) {
loadTerms();
if (weights != null && terms.size() + 1 != weights.size()) {
throw new IOException(String.format(Locale.ROOT,"invalid expression - the number of weights must be %d, found %d", terms.size()+1, weights.size()));
}
}
List<List<Double>> allWeights = new ArrayList();
this.evaluation = new ClassificationEvaluation();
this.error = 0;
for (Future<Tuple> logitCall : callShards(getShardUrls())) {
Tuple tuple = logitCall.get();
List<Double> shardWeights = (List<Double>) tuple.get("weights");
allWeights.add(shardWeights);
this.error += tuple.getDouble("error");
Map shardEvaluation = (Map) tuple.get("evaluation");
this.evaluation.addEvaluation(shardEvaluation);
}
this.weights = averageWeights(allWeights);
Map map = new HashMap();
map.put("id", name+"_"+iteration);
map.put("name_s", name);
map.put("field_s", field);
map.put("terms_ss", terms);
map.put("iteration_i", iteration);
if(weights != null) {
map.put("weights_ds", weights);
}
map.put("error_d", error);
evaluation.putToMap(map);
map.put("alpha_d", this.learningRate);
map.put("idfs_ds", this.idfs);
if (iteration != 1) {
if (lastError <= error) {
this.learningRate *= 0.5;
} else {
this.learningRate *= 1.05;
}
}
lastError = error;
return new Tuple(map);
}
} catch(Exception e) {
throw new IOException(e);
}
}
private List<Double> averageWeights(List<List<Double>> allWeights) {
double[] working = new double[allWeights.get(0).size()];
for(List<Double> shardWeights: allWeights) {
for(int i=0; i<working.length; i++) {
working[i] += shardWeights.get(i);
}
}
for(int i=0; i<working.length; i++) {
working[i] = working[i] / allWeights.size();
}
List<Double> ave = new ArrayList();
for(double d : working) {
ave.add(d);
}
return ave;
}
static String toString(List items) {
StringBuilder buf = new StringBuilder();
for(Object item : items) {
if(buf.length() > 0) {
buf.append(",");
}
buf.append(item.toString());
}
return buf.toString();
}
protected class TermsStream extends TupleStream {
private List<String> terms;
private Iterator<String> it;
public TermsStream(List<String> terms) {
this.terms = terms;
}
@Override
public void setStreamContext(StreamContext context) {}
@Override
public List<TupleStream> children() { return new ArrayList<>(); }
@Override
public void open() throws IOException { this.it = this.terms.iterator();}
@Override
public void close() throws IOException {}
@Override
public Tuple read() throws IOException {
HashMap map = new HashMap();
if(it.hasNext()) {
map.put("term_s",it.next());
map.put("score_f",1.0);
return new Tuple(map);
} else {
map.put("EOF", true);
return new Tuple(map);
}
}
@Override
public StreamComparator getStreamSort() {return null;}
@Override
public Explanation toExplanation(StreamFactory factory) throws IOException {
return new StreamExplanation(getStreamNodeId().toString())
.withFunctionName("non-expressible")
.withImplementingClass(this.getClass().getName())
.withExpressionType(Explanation.ExpressionType.STREAM_SOURCE)
.withExpression("non-expressible");
}
}
protected class LogitCall implements Callable<Tuple> {
private String baseUrl;
private String feature;
private List<String> terms;
private List<Double> weights;
private int iteration;
private String outcome;
private int positiveLabel;
private double learningRate;
private Map<String, String> paramsMap;
public LogitCall(String baseUrl,
Map<String, String> paramsMap,
String feature,
List<String> terms,
List<Double> weights,
String outcome,
int positiveLabel,
double learningRate,
int iteration) {
this.baseUrl = baseUrl;
this.feature = feature;
this.terms = terms;
this.weights = weights;
this.iteration = iteration;
this.outcome = outcome;
this.positiveLabel = positiveLabel;
this.learningRate = learningRate;
this.paramsMap = paramsMap;
}
public Tuple call() throws Exception {
ModifiableSolrParams params = new ModifiableSolrParams();
HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);
params.add("distrib", "false");
params.add("fq","{!tlogit}");
params.add("feature", feature);
params.add("terms", TextLogitStream.toString(terms));
params.add("idfs", TextLogitStream.toString(idfs));
for(String key : paramsMap.keySet()) {
params.add(key, paramsMap.get(key));
}
if(weights != null) {
params.add("weights", TextLogitStream.toString(weights));
}
params.add("iteration", Integer.toString(iteration));
params.add("outcome", outcome);
params.add("positiveLabel", Integer.toString(positiveLabel));
params.add("threshold", Double.toString(threshold));
params.add("alpha", Double.toString(learningRate));
QueryRequest request= new QueryRequest(params, SolrRequest.METHOD.POST);
QueryResponse response = request.process(solrClient);
NamedList res = response.getResponse();
NamedList logit = (NamedList)res.get("logit");
List<Double> shardWeights = (List<Double>)logit.get("weights");
double shardError = (double)logit.get("error");
Map map = new HashMap();
map.put("error", shardError);
map.put("weights", shardWeights);
map.put("evaluation", logit.get("evaluation"));
return new Tuple(map);
}
}
}
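Between iterations, read() adapts the step size in "bold driver" fashion: if the summed shard error fails to improve, alpha is halved; otherwise it is increased slightly:

\alpha_{k+1} =
\begin{cases}
0.5\,\alpha_k & \text{if } E_k \ge E_{k-1} \\
1.05\,\alpha_k & \text{otherwise}
\end{cases}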

@@ -147,6 +147,7 @@ public class Explanation {
public static interface ExpressionType{
public static final String GRAPH_SOURCE = "graph-source";
public static final String MACHINE_LEARNING_MODEL = "ml-model";
public static final String STREAM_SOURCE = "stream-source";
public static final String STREAM_DECORATOR = "stream-decorator";
public static final String DATASTORE = "datastore";

@@ -0,0 +1,77 @@
<?xml version="1.0" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<!-- The Solr schema file. This file should be named "schema.xml" and
should be located where the classloader for the Solr webapp can find it.
This schema is used for testing, and as such has everything and the
kitchen sink thrown in. See example/solr/conf/schema.xml for a
more concise example.
-->
<schema name="test" version="1.6">
<fieldType name="int" docValues="true" class="solr.TrieIntField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="float" docValues="true" class="solr.TrieFloatField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="long" class="solr.TrieLongField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="double" class="solr.TrieDoubleField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="tint" class="solr.TrieIntField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="tfloat" class="solr.TrieFloatField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="tlong" class="solr.TrieLongField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="tdouble" class="solr.TrieDoubleField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="random" class="solr.RandomSortField" indexed="true" />
<fieldtype name="boolean" class="solr.BoolField" sortMissingLast="true"/>
<fieldtype name="string" class="solr.StrField" sortMissingLast="true" docValues="true"/>
<!-- format for date is 1995-12-31T23:59:59.999Z and only the fractional
seconds part (.999) is optional.
-->
<fieldtype name="date" class="solr.TrieDateField" precisionStep="0"/>
<fieldtype name="tdate" class="solr.TrieDateField" precisionStep="6"/>
<field name="id" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
<field name="_version_" type="long" indexed="true" stored="true"/>
<!-- Dynamic field definitions. If a field name is not found, dynamicFields
will be used if the name matches any of the patterns.
RESTRICTION: the glob-like pattern in the name attribute must have
a "*" only at the start or the end.
EXAMPLE: name="*_i" will match any field ending in _i (like myid_i, z_i)
Longer patterns will be matched first. If equal-size patterns
both match, the first appearing in the schema will be used.
-->
<dynamicField name="*_b" type="boolean" indexed="true" stored="true" multiValued="false"/>
<dynamicField name="*_bs" type="boolean" indexed="true" stored="true" multiValued="true"/>
<dynamicField name="*_i" type="int" indexed="true" stored="true" multiValued="false"/>
<dynamicField name="*_is" type="int" indexed="true" stored="true" multiValued="true"/>
<dynamicField name="*_l" type="long" indexed="true" stored="true" multiValued="false"/>
<dynamicField name="*_ls" type="long" indexed="true" stored="true" multiValued="true"/>
<dynamicField name="*_f" type="float" indexed="true" stored="true" multiValued="false"/>
<dynamicField name="*_fs" type="float" indexed="true" stored="true" multiValued="true"/>
<dynamicField name="*_d" type="double" indexed="true" stored="true" multiValued="false"/>
<dynamicField name="*_ds" type="double" indexed="true" stored="true" multiValued="true"/>
<dynamicField name="*_s" type="string" indexed="true" stored="true" multiValued="false"/>
<dynamicField name="*_ss" type="string" indexed="true" stored="true" multiValued="true"/>
<uniqueKey>id</uniqueKey>
</schema>

View File

@ -0,0 +1,51 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<!--
This is a stripped down config file used for a simple example...
It is *not* a good example to work from.
-->
<config>
<luceneMatchVersion>${tests.luceneMatchVersion:LUCENE_CURRENT}</luceneMatchVersion>
<indexConfig>
<useCompoundFile>${useCompoundFile:false}</useCompoundFile>
</indexConfig>
<dataDir>${solr.data.dir:}</dataDir>
<directoryFactory name="DirectoryFactory" class="${solr.directoryFactory:solr.StandardDirectoryFactory}"/>
<schemaFactory class="ClassicIndexSchemaFactory"/>
<updateHandler class="solr.DirectUpdateHandler2">
<updateLog>
<str name="dir">${solr.data.dir:}</str>
</updateLog>
</updateHandler>
<requestDispatcher handleSelect="true" >
<requestParsers enableRemoteStreaming="false" multipartUploadLimitInKB="2048" />
</requestDispatcher>
<requestHandler name="standard" class="solr.StandardRequestHandler" default="true" />
<!-- config for the admin interface -->
<admin>
<defaultQuery>solr</defaultQuery>
</admin>
</config>

View File

@ -16,16 +16,23 @@
*/
package org.apache.solr.client.solrj.io.stream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.LuceneTestCase.Slow;
import org.apache.solr.client.solrj.embedded.JettySolrRunner;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
@ -71,6 +78,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
public static void setupCluster() throws Exception {
configureCluster(4)
.addConfig("conf", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("streaming").resolve("conf"))
.addConfig("ml", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("ml").resolve("conf"))
.configure();
CollectionAdminRequest.createCollection(COLLECTION, "conf", 2, 1).process(cluster.getSolrClient());
@ -2773,6 +2781,8 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assert(tuple.getDouble("a_f") == 4.0);
assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
}
@Test
@ -2863,6 +2873,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
CollectionAdminRequest.deleteCollection("parallelDestinationCollection").process(cluster.getSolrClient());
}
@Test
@ -3025,6 +3036,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
CollectionAdminRequest.deleteCollection("parallelDestinationCollection1").process(cluster.getSolrClient());
}
@Test
@ -3065,6 +3077,121 @@ public class StreamExpressionTest extends SolrCloudTestCase {
}
@Test
public void testBasicTextLogitStream() throws Exception {
CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(),
false, true, TIMEOUT);
UpdateRequest updateRequest = new UpdateRequest();
for (int i = 0; i < 5000; i+=2) {
updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "1");
updateRequest.add(id, String.valueOf(i+1), "tv_text", "a b e e f", "out_i", "0");
}
updateRequest.commit(cluster.getSolrClient(), COLLECTION);
StreamExpression expression;
TupleStream stream;
List<Tuple> tuples;
StreamFactory factory = new StreamFactory()
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
.withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress())
.withFunctionName("features", FeaturesSelectionStream.class)
.withFunctionName("train", TextLogitStream.class)
.withFunctionName("search", CloudSolrStream.class)
.withFunctionName("update", UpdateStream.class);
expression = StreamExpressionParser.parse("features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4)");
stream = new FeaturesSelectionStream(expression, factory);
tuples = getTuples(stream);
assertEquals(4, tuples.size());
HashSet<String> terms = new HashSet<>();
for (Tuple tuple : tuples) {
terms.add((String) tuple.get("term_s"));
}
assertTrue(terms.contains("d"));
assertTrue(terms.contains("c"));
assertTrue(terms.contains("e"));
assertTrue(terms.contains("f"));
String textLogitExpression = "train(" +
"collection1, " +
"features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4),"+
"q=\"*:*\", " +
"name=\"model\", " +
"field=\"tv_text\", " +
"outcome=\"out_i\", " +
"maxIterations=100)";
stream = factory.constructStream(textLogitExpression);
tuples = getTuples(stream);
Tuple lastTuple = tuples.get(tuples.size() - 1);
List<Double> lastWeights = lastTuple.getDoubles("weights_ds");
Double[] lastWeightsArray = lastWeights.toArray(new Double[lastWeights.size()]);
// first feature is the bias input; this record activates only the class-1 features,
// so the model should predict a probability near 1
Double[] testRecord = {1.0, 1.17, 0.691, 0.0, 0.0};
double d = sum(multiply(testRecord, lastWeightsArray));
double prob = sigmoid(d);
assertEquals(1.0, prob, 0.1);
// this record activates only the class-0 features, so the prediction should be near 0
Double[] testRecord2 = {1.0, 0.0, 0.0, 1.17, 0.691};
d = sum(multiply(testRecord2, lastWeightsArray));
prob = sigmoid(d);
assertEquals(0.0, prob, 0.1);
stream = factory.constructStream("update(destinationCollection, batchSize=5, "+textLogitExpression+")");
getTuples(stream);
cluster.getSolrClient().commit("destinationCollection");
stream = factory.constructStream("search(destinationCollection, " +
"q=*:*, " +
"fl=\"iteration_i,* \", " +
"rows=100, " +
"sort=\"iteration_i desc\")");
tuples = getTuples(stream);
assertEquals(100, tuples.size());
Tuple lastModel = tuples.get(0);
ClassificationEvaluation evaluation = ClassificationEvaluation.create(lastModel.fields);
assertTrue(evaluation.getF1() >= 1.0);
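// each selected term occurs in exactly 2500 of the 5000 docs, so idf = log(numDocs / (docFreq + 1))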
assertEquals(Math.log(5000.0 / (2500 + 1)), lastModel.getDoubles("idfs_ds").get(0), 0.0001);
// make sure the tuples are returned in the correct order
Tuple firstTuple = tuples.get(99);
assertEquals(1L, (long) firstTuple.getLong("iteration_i"));
CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
}
private double sigmoid(double in) {
  return 1.0 / (1 + Math.exp(-in));
}
private double[] multiply(Double[] vec1, Double[] vec2) {
  double[] working = new double[vec1.length];
  for (int i = 0; i < vec1.length; i++) {
    working[i] = vec1[i] * vec2[i];
  }
  return working;
}
private double sum(double[] vec) {
  double d = 0.0;
  for (double v : vec) {
    d += v;
  }
  return d;
}
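Together, the three helpers above perform the standard logistic-regression prediction used in the assertions: prepend 1.0 to the features for the bias input, dot the vector with the learned weights, and squash the score through the sigmoid. A worked example with hypothetical weights (the test reads the real values from the model tuple's weights_ds field):

// Hypothetical weights in the order [bias, w1, w2, w3, w4].
Double[] weights = {0.1, 2.0, 1.5, -2.0, -1.5};
Double[] record = {1.0, 1.17, 0.691, 0.0, 0.0}; // leading 1.0 feeds the bias weight
double score = sum(multiply(record, weights));  // dot product, here 3.4765
double prob = sigmoid(score);                   // P(out_i = 1 | record), here ~0.97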
@Test
public void testParallelIntersectStream() throws Exception {
@ -3103,6 +3230,62 @@ public class StreamExpressionTest extends SolrCloudTestCase {
}
@Test
public void testFeaturesSelectionStream() throws Exception {
CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(),
false, true, TIMEOUT);
UpdateRequest updateRequest = new UpdateRequest();
for (int i = 0; i < 5000; i+=2) {
updateRequest.add(id, String.valueOf(i), "whitetok", "a b c d", "out_i", "1");
updateRequest.add(id, String.valueOf(i+1), "whitetok", "a b e f", "out_i", "0");
}
updateRequest.commit(cluster.getSolrClient(), COLLECTION);
StreamExpression expression;
TupleStream stream;
List<Tuple> tuples;
StreamFactory factory = new StreamFactory()
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
.withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress())
.withFunctionName("featuresSelection", FeaturesSelectionStream.class)
.withFunctionName("search", CloudSolrStream.class)
.withFunctionName("update", UpdateStream.class);
String featuresExpression = "featuresSelection(collection1, q=\"*:*\", featureSet=\"first\", field=\"whitetok\", outcome=\"out_i\", numTerms=4)";
// basic
expression = StreamExpressionParser.parse(featuresExpression);
stream = new FeaturesSelectionStream(expression, factory);
tuples = getTuples(stream);
assertEquals(4, tuples.size());
assertEquals("c", tuples.get(0).get("term_s"));
assertEquals("d", tuples.get(1).get("term_s"));
assertEquals("e", tuples.get(2).get("term_s"));
assertEquals("f", tuples.get(3).get("term_s"));
// update
expression = StreamExpressionParser.parse("update(destinationCollection, batchSize=5, "+featuresExpression+")");
stream = new UpdateStream(expression, factory);
getTuples(stream);
cluster.getSolrClient().commit("destinationCollection");
expression = StreamExpressionParser.parse("search(destinationCollection, q=featureSet_s:first, fl=\"index_i, term_s\", sort=\"index_i asc\")");
stream = new CloudSolrStream(expression, factory);
tuples = getTuples(stream);
assertEquals(4, tuples.size());
assertEquals("c", tuples.get(0).get("term_s"));
assertEquals("d", tuples.get(1).get("term_s"));
assertEquals("e", tuples.get(2).get("term_s"));
assertEquals("f", tuples.get(3).get("term_s"));
CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
}
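Why c, d, e, and f beat a and b: the terms a and b occur in every document, so they carry no information about out_i, while each of c, d, e, and f occurs in only one class. A minimal sketch of binary information gain (the standard definition, not the committed scoring code) makes this concrete:

// IG(term) = H(outcome) - [P(term) * H(outcome|term) + P(!term) * H(outcome|!term)]
private static double log2(double x) {
  return Math.log(x) / Math.log(2);
}
private static double entropy(double p) {
  if (p == 0.0 || p == 1.0) return 0.0;
  return -p * log2(p) - (1 - p) * log2(1 - p);
}
private static double infoGain(int docs, int positives, int docsWithTerm, int positivesWithTerm) {
  double hOutcome = entropy((double) positives / docs);
  double pTerm = (double) docsWithTerm / docs;
  double hTerm = docsWithTerm == 0 ? 0.0 : entropy((double) positivesWithTerm / docsWithTerm);
  double hNoTerm = docs == docsWithTerm ? 0.0
      : entropy((double) (positives - positivesWithTerm) / (docs - docsWithTerm));
  return hOutcome - (pTerm * hTerm + (1 - pTerm) * hNoTerm);
}
// infoGain(5000, 2500, 5000, 2500) == 0.0 for "a" and "b" (present in every doc);
// infoGain(5000, 2500, 2500, 2500) == 1.0 for "c" and "d" (perfectly predictive).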
@Test
public void testComplementStream() throws Exception {

View File

@ -62,6 +62,8 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
.withFunctionName("avg", MeanMetric.class)
.withFunctionName("daemon", DaemonStream.class)
.withFunctionName("topic", TopicStream.class)
.withFunctionName("tlogit", TextLogitStream.class)
.withFunctionName("featuresSelection", FeaturesSelectionStream.class)
;
}
@ -138,7 +140,6 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
assertTrue(expressionString.contains("checkpointEvery=1000"));
}
@Test
public void testStatsStream() throws Exception {
@ -343,6 +344,40 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
assertTrue(secondExpressionString.contains("q=\"presentTitles:\\\"chief, executive officer\\\" AND age:[36 TO *]\""));
}
@Test
public void testFeaturesSelectionStream() throws Exception {
String expr = "featuresSelection(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4, positiveLabel=2)";
FeaturesSelectionStream stream = new FeaturesSelectionStream(StreamExpressionParser.parse(expr), factory);
String expressionString = stream.toExpression(factory).toString();
assertTrue(expressionString.contains("q=\"*:*\""));
assertTrue(expressionString.contains("featureSet=first"));
assertTrue(expressionString.contains("field=tv_text"));
assertTrue(expressionString.contains("outcome=out_i"));
assertTrue(expressionString.contains("numTerms=4"));
assertTrue(expressionString.contains("positiveLabel=2"));
}
@Test
public void testTextLogitStreamWithFeaturesSelection() throws Exception {
String expr = "tlogit(" +
"collection1, " +
"q=\"*:*\", " +
"name=\"model\", " +
"featuresSelection(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4), " +
"field=\"tv_text\", " +
"outcome=\"out_i\", " +
"maxIterations=100)";
TextLogitStream logitStream = new TextLogitStream(StreamExpressionParser.parse(expr), factory);
String expressionString = logitStream.toExpression(factory).toString();
assertTrue(expressionString.contains("q=\"*:*\""));
assertTrue(expressionString.contains("name=model"));
assertFalse(expressionString.contains("terms="));
assertTrue(expressionString.contains("featuresSelection("));
assertTrue(expressionString.contains("field=tv_text"));
assertTrue(expressionString.contains("outcome=out_i"));
assertTrue(expressionString.contains("maxIterations=100"));
}
@Test
public void testCountMetric() throws Exception {

View File

@ -17,7 +17,6 @@
package org.apache.solr.client.solrj.io.stream;
import junit.framework.Assert;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.client.solrj.io.ops.GroupOperation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;