mirror of https://github.com/apache/lucene.git
SOLR-9252: Feature selection and logistic regression on text
This commit is contained in:
parent
9fc4624853
commit
87938e00e9
|
@ -122,6 +122,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
|||
.withFunctionName("intersect", IntersectStream.class)
|
||||
.withFunctionName("complement", ComplementStream.class)
|
||||
.withFunctionName("sort", SortStream.class)
|
||||
.withFunctionName("train", TextLogitStream.class)
|
||||
.withFunctionName("features", FeaturesSelectionStream.class)
|
||||
.withFunctionName("daemon", DaemonStream.class)
|
||||
.withFunctionName("shortestPath", ShortestPathStream.class)
|
||||
.withFunctionName("gatherNodes", GatherNodesStream.class)
|
||||
|
|
|
@ -0,0 +1,239 @@
|
|||
package org.apache.solr.search;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.NumericDocValues;
|
||||
import org.apache.lucene.index.PostingsEnum;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.SparseFixedBitSet;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.handler.component.ResponseBuilder;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
|
||||
public class IGainTermsQParserPlugin extends QParserPlugin {
|
||||
|
||||
public static final String NAME = "igain";
|
||||
|
||||
@Override
|
||||
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
|
||||
return new IGainTermsQParser(qstr, localParams, params, req);
|
||||
}
|
||||
|
||||
private static class IGainTermsQParser extends QParser {
|
||||
|
||||
public IGainTermsQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
|
||||
super(qstr, localParams, params, req);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Query parse() throws SyntaxError {
|
||||
|
||||
String field = getParam("field");
|
||||
String outcome = getParam("outcome");
|
||||
int numTerms = Integer.parseInt(getParam("numTerms"));
|
||||
int positiveLabel = Integer.parseInt(getParam("positiveLabel"));
|
||||
|
||||
return new IGainTermsQuery(field, outcome, positiveLabel, numTerms);
|
||||
}
|
||||
}
|
||||
|
||||
private static class IGainTermsQuery extends AnalyticsQuery {
|
||||
|
||||
private String field;
|
||||
private String outcome;
|
||||
private int numTerms;
|
||||
private int positiveLabel;
|
||||
|
||||
public IGainTermsQuery(String field, String outcome, int positiveLabel, int numTerms) {
|
||||
this.field = field;
|
||||
this.outcome = outcome;
|
||||
this.numTerms = numTerms;
|
||||
this.positiveLabel = positiveLabel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DelegatingCollector getAnalyticsCollector(ResponseBuilder rb, IndexSearcher searcher) {
|
||||
return new IGainTermsCollector(rb, searcher, field, outcome, positiveLabel, numTerms);
|
||||
}
|
||||
}
|
||||
|
||||
private static class IGainTermsCollector extends DelegatingCollector {
|
||||
|
||||
private String field;
|
||||
private String outcome;
|
||||
private IndexSearcher searcher;
|
||||
private ResponseBuilder rb;
|
||||
private int positiveLabel;
|
||||
private int numTerms;
|
||||
private int count;
|
||||
|
||||
private NumericDocValues leafOutcomeValue;
|
||||
private SparseFixedBitSet positiveSet;
|
||||
private SparseFixedBitSet negativeSet;
|
||||
|
||||
|
||||
private int numPositiveDocs;
|
||||
|
||||
|
||||
public IGainTermsCollector(ResponseBuilder rb, IndexSearcher searcher, String field, String outcome, int positiveLabel, int numTerms) {
|
||||
this.rb = rb;
|
||||
this.searcher = searcher;
|
||||
this.field = field;
|
||||
this.outcome = outcome;
|
||||
this.positiveSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
|
||||
this.negativeSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
|
||||
|
||||
this.numTerms = numTerms;
|
||||
this.positiveLabel = positiveLabel;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doSetNextReader(LeafReaderContext context) throws IOException {
|
||||
super.doSetNextReader(context);
|
||||
LeafReader reader = context.reader();
|
||||
leafOutcomeValue = reader.getNumericDocValues(outcome);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void collect(int doc) throws IOException {
|
||||
super.collect(doc);
|
||||
++count;
|
||||
if (leafOutcomeValue.get(doc) == positiveLabel) {
|
||||
positiveSet.set(context.docBase + doc);
|
||||
numPositiveDocs++;
|
||||
} else {
|
||||
negativeSet.set(context.docBase + doc);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() throws IOException {
|
||||
NamedList<Double> analytics = new NamedList<Double>();
|
||||
NamedList<Integer> topFreq = new NamedList();
|
||||
|
||||
NamedList<Integer> allFreq = new NamedList();
|
||||
|
||||
rb.rsp.add("featuredTerms", analytics);
|
||||
rb.rsp.add("docFreq", topFreq);
|
||||
rb.rsp.add("numDocs", count);
|
||||
|
||||
TreeSet<TermWithScore> topTerms = new TreeSet<>();
|
||||
|
||||
double numDocs = count;
|
||||
double pc = numPositiveDocs / numDocs;
|
||||
double entropyC = binaryEntropy(pc);
|
||||
|
||||
Terms terms = MultiFields.getFields(searcher.getIndexReader()).terms(field);
|
||||
TermsEnum termsEnum = terms.iterator();
|
||||
BytesRef term;
|
||||
PostingsEnum postingsEnum = null;
|
||||
while ((term = termsEnum.next()) != null) {
|
||||
postingsEnum = termsEnum.postings(postingsEnum);
|
||||
int xc = 0;
|
||||
int nc = 0;
|
||||
while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
if (positiveSet.get(postingsEnum.docID())) {
|
||||
xc++;
|
||||
} else if (negativeSet.get(postingsEnum.docID())) {
|
||||
nc++;
|
||||
}
|
||||
}
|
||||
|
||||
int docFreq = xc+nc;
|
||||
|
||||
double entropyContainsTerm = binaryEntropy( (double) xc / docFreq );
|
||||
double entropyNotContainsTerm = binaryEntropy( (double) (numPositiveDocs - xc) / (numDocs - docFreq + 1) );
|
||||
double score = entropyC - ( (docFreq / numDocs) * entropyContainsTerm + (1.0 - docFreq / numDocs) * entropyNotContainsTerm);
|
||||
|
||||
topFreq.add(term.utf8ToString(), docFreq);
|
||||
if (topTerms.size() < numTerms) {
|
||||
topTerms.add(new TermWithScore(term.utf8ToString(), score));
|
||||
} else {
|
||||
if (topTerms.first().score < score) {
|
||||
topTerms.pollFirst();
|
||||
topTerms.add(new TermWithScore(term.utf8ToString(), score));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (TermWithScore topTerm : topTerms) {
|
||||
analytics.add(topTerm.term, topTerm.score);
|
||||
topFreq.add(topTerm.term, allFreq.get(topTerm.term));
|
||||
}
|
||||
|
||||
if (this.delegate instanceof DelegatingCollector) {
|
||||
((DelegatingCollector) this.delegate).finish();
|
||||
}
|
||||
}
|
||||
|
||||
private double binaryEntropy(double prob) {
|
||||
if (prob == 0 || prob == 1) return 0;
|
||||
return (-1 * prob * Math.log(prob)) + (-1 * (1.0 - prob) * Math.log(1.0 - prob));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
private static class TermWithScore implements Comparable<TermWithScore>{
|
||||
public final String term;
|
||||
public final double score;
|
||||
|
||||
public TermWithScore(String term, double score) {
|
||||
this.term = term;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return term.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == null) return false;
|
||||
if (obj.getClass() != getClass()) return false;
|
||||
TermWithScore other = (TermWithScore) obj;
|
||||
return other.term.equals(this.term);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(TermWithScore o) {
|
||||
int cmp = Double.compare(this.score, o.score);
|
||||
if (cmp == 0) {
|
||||
return this.term.compareTo(o.term);
|
||||
} else {
|
||||
return cmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -16,6 +16,11 @@
|
|||
*/
|
||||
package org.apache.solr.search;
|
||||
|
||||
import java.net.URL;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.core.SolrInfoMBean;
|
||||
|
@ -26,11 +31,6 @@ import org.apache.solr.search.join.GraphQParserPlugin;
|
|||
import org.apache.solr.search.mlt.MLTQParserPlugin;
|
||||
import org.apache.solr.util.plugin.NamedListInitializedPlugin;
|
||||
|
||||
import java.net.URL;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrInfoMBean {
|
||||
/** internal use - name of the default parser */
|
||||
public static final String DEFAULT_QTYPE = LuceneQParserPlugin.NAME;
|
||||
|
@ -77,6 +77,8 @@ public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrI
|
|||
map.put(GraphQParserPlugin.NAME, GraphQParserPlugin.class);
|
||||
map.put(XmlQParserPlugin.NAME, XmlQParserPlugin.class);
|
||||
map.put(GraphTermsQParserPlugin.NAME, GraphTermsQParserPlugin.class);
|
||||
map.put(IGainTermsQParserPlugin.NAME, IGainTermsQParserPlugin.class);
|
||||
map.put(TextLogisticRegressionQParserPlugin.NAME, TextLogisticRegressionQParserPlugin.class);
|
||||
standardPlugins = Collections.unmodifiableMap(map);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,283 @@
|
|||
package org.apache.solr.search;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.MultiFields;
|
||||
import org.apache.lucene.index.NumericDocValues;
|
||||
import org.apache.lucene.index.PostingsEnum;
|
||||
import org.apache.lucene.index.Terms;
|
||||
import org.apache.lucene.index.TermsEnum;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.SparseFixedBitSet;
|
||||
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.handler.component.ResponseBuilder;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
|
||||
/**
|
||||
* Returns an AnalyticsQuery implementation that performs
|
||||
* one Gradient Descent iteration of a result set to train a
|
||||
* logistic regression model
|
||||
*
|
||||
* The TextLogitStream provides the parallel iterative framework for this class.
|
||||
**/
|
||||
|
||||
public class TextLogisticRegressionQParserPlugin extends QParserPlugin {
|
||||
public static final String NAME = "tlogit";
|
||||
|
||||
@Override
|
||||
public void init(NamedList args) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
|
||||
return new TextLogisticRegressionQParser(qstr, localParams, params, req);
|
||||
}
|
||||
|
||||
private static class TextLogisticRegressionQParser extends QParser{
|
||||
|
||||
TextLogisticRegressionQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
|
||||
super(qstr, localParams, params, req);
|
||||
}
|
||||
|
||||
public Query parse() {
|
||||
|
||||
String fs = params.get("feature");
|
||||
String[] terms = params.get("terms").split(",");
|
||||
String ws = params.get("weights");
|
||||
String dfsStr = params.get("idfs");
|
||||
int iteration = params.getInt("iteration");
|
||||
String outcome = params.get("outcome");
|
||||
int positiveLabel = params.getInt("positiveLabel", 1);
|
||||
double threshold = params.getDouble("threshold", 0.5);
|
||||
double alpha = params.getDouble("alpha", 0.01);
|
||||
|
||||
double[] idfs = new double[terms.length];
|
||||
String[] idfsArr = dfsStr.split(",");
|
||||
for (int i = 0; i < idfsArr.length; i++) {
|
||||
idfs[i] = Double.parseDouble(idfsArr[i]);
|
||||
}
|
||||
|
||||
double[] weights = new double[terms.length+1];
|
||||
|
||||
if(ws != null) {
|
||||
String[] wa = ws.split(",");
|
||||
for (int i = 0; i < wa.length; i++) {
|
||||
weights[i] = Double.parseDouble(wa[i]);
|
||||
}
|
||||
} else {
|
||||
for(int i=0; i<weights.length; i++) {
|
||||
weights[i]= 1.0d;
|
||||
}
|
||||
}
|
||||
|
||||
TrainingParams input = new TrainingParams(fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold);
|
||||
|
||||
return new TextLogisticRegressionQuery(input);
|
||||
}
|
||||
}
|
||||
|
||||
private static class TextLogisticRegressionQuery extends AnalyticsQuery {
|
||||
private TrainingParams trainingParams;
|
||||
|
||||
public TextLogisticRegressionQuery(TrainingParams trainingParams) {
|
||||
this.trainingParams = trainingParams;
|
||||
}
|
||||
|
||||
public DelegatingCollector getAnalyticsCollector(ResponseBuilder rbsp, IndexSearcher indexSearcher) {
|
||||
return new TextLogisticRegressionCollector(rbsp, indexSearcher, trainingParams);
|
||||
}
|
||||
}
|
||||
|
||||
private static class TextLogisticRegressionCollector extends DelegatingCollector {
|
||||
private TrainingParams trainingParams;
|
||||
private LeafReader leafReader;
|
||||
|
||||
private double[] workingDeltas;
|
||||
private ClassificationEvaluation classificationEvaluation;
|
||||
private double[] weights;
|
||||
|
||||
private ResponseBuilder rbsp;
|
||||
private NumericDocValues leafOutcomeValue;
|
||||
private double totalError;
|
||||
private SparseFixedBitSet positiveDocsSet;
|
||||
private SparseFixedBitSet docsSet;
|
||||
private IndexSearcher searcher;
|
||||
|
||||
TextLogisticRegressionCollector(ResponseBuilder rbsp, IndexSearcher searcher,
|
||||
TrainingParams trainingParams) {
|
||||
this.trainingParams = trainingParams;
|
||||
this.workingDeltas = new double[trainingParams.weights.length];
|
||||
this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length);
|
||||
this.rbsp = rbsp;
|
||||
this.classificationEvaluation = new ClassificationEvaluation();
|
||||
this.searcher = searcher;
|
||||
positiveDocsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
|
||||
docsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
|
||||
}
|
||||
|
||||
public void doSetNextReader(LeafReaderContext context) throws IOException {
|
||||
super.doSetNextReader(context);
|
||||
leafReader = context.reader();
|
||||
leafOutcomeValue = leafReader.getNumericDocValues(trainingParams.outcome);
|
||||
}
|
||||
|
||||
public void collect(int doc) throws IOException{
|
||||
|
||||
int outcome = (int) leafOutcomeValue.get(doc);
|
||||
outcome = trainingParams.positiveLabel == outcome? 1 : 0;
|
||||
if (outcome == 1) {
|
||||
positiveDocsSet.set(context.docBase + doc);
|
||||
}
|
||||
docsSet.set(context.docBase+doc);
|
||||
|
||||
}
|
||||
|
||||
public void finish() throws IOException {
|
||||
|
||||
Map<Integer, double[]> docVectors = new HashMap<>();
|
||||
Terms terms = MultiFields.getFields(searcher.getIndexReader()).terms(trainingParams.feature);
|
||||
TermsEnum termsEnum = terms.iterator();
|
||||
PostingsEnum postingsEnum = null;
|
||||
int termIndex = 0;
|
||||
for (String termStr : trainingParams.terms) {
|
||||
BytesRef term = new BytesRef(termStr);
|
||||
if (termsEnum.seekExact(term)) {
|
||||
postingsEnum = termsEnum.postings(postingsEnum);
|
||||
while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
int docId = postingsEnum.docID();
|
||||
if (docsSet.get(docId)) {
|
||||
double[] vector = docVectors.get(docId);
|
||||
if (vector == null) {
|
||||
vector = new double[trainingParams.terms.length+1];
|
||||
vector[0] = 1.0;
|
||||
docVectors.put(docId, vector);
|
||||
}
|
||||
vector[termIndex + 1] = trainingParams.idfs[termIndex] * (1.0 + Math.log(postingsEnum.freq()));
|
||||
}
|
||||
}
|
||||
}
|
||||
termIndex++;
|
||||
}
|
||||
|
||||
for (Map.Entry<Integer, double[]> entry : docVectors.entrySet()) {
|
||||
double[] vector = entry.getValue();
|
||||
int outcome = 0;
|
||||
if (positiveDocsSet.get(entry.getKey())) {
|
||||
outcome = 1;
|
||||
}
|
||||
double sig = sigmoid(sum(multiply(vector, weights)));
|
||||
double error = sig - outcome;
|
||||
double lastSig = sigmoid(sum(multiply(vector, trainingParams.weights)));
|
||||
totalError += Math.abs(lastSig - outcome);
|
||||
classificationEvaluation.count(outcome, lastSig >= trainingParams.threshold ? 1 : 0);
|
||||
|
||||
workingDeltas = multiply(error * trainingParams.alpha, vector);
|
||||
|
||||
for(int i = 0; i< workingDeltas.length; i++) {
|
||||
weights[i] -= workingDeltas[i];
|
||||
}
|
||||
}
|
||||
|
||||
NamedList analytics = new NamedList();
|
||||
rbsp.rsp.add("logit", analytics);
|
||||
|
||||
List<Double> outWeights = new ArrayList<>();
|
||||
for(Double d : weights) {
|
||||
outWeights.add(d);
|
||||
}
|
||||
|
||||
analytics.add("weights", outWeights);
|
||||
analytics.add("error", totalError);
|
||||
analytics.add("evaluation", classificationEvaluation.toMap());
|
||||
analytics.add("feature", trainingParams.feature);
|
||||
analytics.add("positiveLabel", trainingParams.positiveLabel);
|
||||
if(this.delegate instanceof DelegatingCollector) {
|
||||
((DelegatingCollector)this.delegate).finish();
|
||||
}
|
||||
}
|
||||
|
||||
private double sigmoid(double in) {
|
||||
double d = 1.0 / (1+Math.exp(-in));
|
||||
return d;
|
||||
}
|
||||
|
||||
private double[] multiply(double[] vals, double[] weights) {
|
||||
for(int i = 0; i < vals.length; ++i) {
|
||||
workingDeltas[i] = vals[i] * weights[i];
|
||||
}
|
||||
|
||||
return workingDeltas;
|
||||
}
|
||||
|
||||
private double[] multiply(double d, double[] vals) {
|
||||
for(int i = 0; i<vals.length; ++i) {
|
||||
workingDeltas[i] = vals[i] * d;
|
||||
}
|
||||
|
||||
return workingDeltas;
|
||||
}
|
||||
|
||||
private double sum(double[] vals) {
|
||||
double d = 0.0d;
|
||||
for(double val : vals) {
|
||||
d += val;
|
||||
}
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static class TrainingParams {
|
||||
public final String feature;
|
||||
public final String[] terms;
|
||||
public final double[] idfs;
|
||||
public final String outcome;
|
||||
public final double[] weights;
|
||||
public final int interation;
|
||||
public final int positiveLabel;
|
||||
public final double threshold;
|
||||
public final double alpha;
|
||||
|
||||
public TrainingParams(String feature, String[] terms, double[] idfs, String outcome, double[] weights, int interation, double alpha, int positiveLabel, double threshold) {
|
||||
this.feature = feature;
|
||||
this.terms = terms;
|
||||
this.idfs = idfs;
|
||||
this.outcome = outcome;
|
||||
this.weights = weights;
|
||||
this.alpha = alpha;
|
||||
this.interation = interation;
|
||||
this.positiveLabel = positiveLabel;
|
||||
this.threshold = threshold;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -175,6 +175,24 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
|
|||
}
|
||||
}
|
||||
|
||||
public void testTlogitQuery() throws Exception {
|
||||
SolrQueryRequest req = req("q", "*:*", "feature", "f", "terms","a,b,c", "weights", "100,200,300", "idfs","1,5,7","iteration","1", "outcome","a","positiveLabel","1");
|
||||
try {
|
||||
assertQueryEquals("tlogit", req, "{!tlogit}");
|
||||
} finally {
|
||||
req.close();
|
||||
}
|
||||
}
|
||||
|
||||
public void testIGainQuery() throws Exception {
|
||||
SolrQueryRequest req = req("q", "*:*", "outcome", "b", "positiveLabel", "1", "field", "x", "numTerms","200");
|
||||
try {
|
||||
assertQueryEquals("igain", req, "{!igain}");
|
||||
} finally {
|
||||
req.close();
|
||||
}
|
||||
}
|
||||
|
||||
public void testQuerySwitch() throws Exception {
|
||||
SolrQueryRequest req = req("myXXX", "XXX",
|
||||
"myField", "foo_s",
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
package org.apache.solr.client.solrj.io;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class ClassificationEvaluation {
|
||||
private long truePositive;
|
||||
private long falsePositive;
|
||||
private long trueNegative;
|
||||
private long falseNegative;
|
||||
|
||||
public void count(int actual, int predicted) {
|
||||
if (predicted == 1) {
|
||||
if (actual == 1) truePositive++;
|
||||
else falsePositive++;
|
||||
} else {
|
||||
if (actual == 0) trueNegative++;
|
||||
else falseNegative++;
|
||||
}
|
||||
}
|
||||
|
||||
public void putToMap(Map map) {
|
||||
map.put("truePositive_i",truePositive);
|
||||
map.put("trueNegative_i",trueNegative);
|
||||
map.put("falsePositive_i",falsePositive);
|
||||
map.put("falseNegative_i",falseNegative);
|
||||
}
|
||||
|
||||
public Map toMap() {
|
||||
HashMap map = new HashMap();
|
||||
putToMap(map);
|
||||
return map;
|
||||
}
|
||||
|
||||
public static ClassificationEvaluation create(Map map) {
|
||||
ClassificationEvaluation evaluation = new ClassificationEvaluation();
|
||||
evaluation.addEvaluation(map);
|
||||
return evaluation;
|
||||
}
|
||||
|
||||
public void addEvaluation(Map map) {
|
||||
this.truePositive += (long) map.get("truePositive_i");
|
||||
this.trueNegative += (long) map.get("trueNegative_i");
|
||||
this.falsePositive += (long) map.get("falsePositive_i");
|
||||
this.falseNegative += (long) map.get("falseNegative_i");
|
||||
}
|
||||
|
||||
public double getPrecision() {
|
||||
if (truePositive + falsePositive == 0) return 0;
|
||||
return (double) truePositive / (truePositive + falsePositive);
|
||||
}
|
||||
|
||||
public double getRecall() {
|
||||
if (truePositive + falseNegative == 0) return 0;
|
||||
return (double) truePositive / (truePositive + falseNegative);
|
||||
}
|
||||
|
||||
public double getF1() {
|
||||
double precision = getPrecision();
|
||||
double recall = getRecall();
|
||||
if (precision + recall == 0) return 0;
|
||||
return 2 * (precision * recall) / (precision + recall);
|
||||
}
|
||||
|
||||
public double getAccuracy() {
|
||||
return (double) (truePositive + trueNegative) / (truePositive + trueNegative + falseNegative + falsePositive);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,436 @@
|
|||
package org.apache.solr.client.solrj.io.stream;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import org.apache.solr.client.solrj.impl.CloudSolrClient;
|
||||
import org.apache.solr.client.solrj.impl.HttpSolrClient;
|
||||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
import org.apache.solr.client.solrj.request.QueryRequest;
|
||||
import org.apache.solr.client.solrj.response.QueryResponse;
|
||||
import org.apache.solr.common.cloud.ClusterState;
|
||||
import org.apache.solr.common.cloud.Replica;
|
||||
import org.apache.solr.common.cloud.Slice;
|
||||
import org.apache.solr.common.cloud.ZkCoreNodeProps;
|
||||
import org.apache.solr.common.cloud.ZkStateReader;
|
||||
import org.apache.solr.common.params.ModifiableSolrParams;
|
||||
import org.apache.solr.common.util.ExecutorUtil;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.common.util.SolrjNamedThreadFactory;
|
||||
|
||||
public class FeaturesSelectionStream extends TupleStream implements Expressible{
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
protected String zkHost;
|
||||
protected String collection;
|
||||
protected Map<String,String> params;
|
||||
protected Iterator<Tuple> tupleIterator;
|
||||
protected String field;
|
||||
protected String outcome;
|
||||
protected String featureSet;
|
||||
protected int positiveLabel;
|
||||
protected int numTerms;
|
||||
|
||||
|
||||
|
||||
protected transient SolrClientCache cache;
|
||||
protected transient boolean isCloseCache;
|
||||
protected transient CloudSolrClient cloudSolrClient;
|
||||
|
||||
protected transient StreamContext streamContext;
|
||||
protected ExecutorService executorService;
|
||||
|
||||
|
||||
public FeaturesSelectionStream(String zkHost,
|
||||
String collectionName,
|
||||
Map params,
|
||||
String field,
|
||||
String outcome,
|
||||
String featureSet,
|
||||
int positiveLabel,
|
||||
int numTerms) throws IOException {
|
||||
|
||||
init(collectionName, zkHost, params, field, outcome, featureSet, positiveLabel, numTerms);
|
||||
}
|
||||
|
||||
/**
|
||||
* logit(collection, zkHost="", features="a,b,c,d,e,f,g", outcome="y", maxIteration="20")
|
||||
**/
|
||||
|
||||
public FeaturesSelectionStream(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
// grab all parameters out
|
||||
String collectionName = factory.getValueOperand(expression, 0);
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
|
||||
|
||||
// Validate there are no unknown parameters - zkHost and alias are namedParameter so we don't need to count it twice
|
||||
if(expression.getParameters().size() != 1 + namedParams.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
|
||||
}
|
||||
|
||||
// Collection Name
|
||||
if(null == collectionName){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
|
||||
}
|
||||
|
||||
// Named parameters - passed directly to solr as solrparams
|
||||
if(0 == namedParams.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
|
||||
}
|
||||
|
||||
Map<String,String> params = new HashMap<String,String>();
|
||||
for(StreamExpressionNamedParameter namedParam : namedParams){
|
||||
if(!namedParam.getName().equals("zkHost")) {
|
||||
params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
|
||||
}
|
||||
}
|
||||
|
||||
String fieldParam = params.get("field");
|
||||
if(fieldParam != null) {
|
||||
params.remove("field");
|
||||
} else {
|
||||
throw new IOException("field param cannot be null for FeaturesSelectionStream");
|
||||
}
|
||||
|
||||
String outcomeParam = params.get("outcome");
|
||||
if(outcomeParam != null) {
|
||||
params.remove("outcome");
|
||||
} else {
|
||||
throw new IOException("outcome param cannot be null for FeaturesSelectionStream");
|
||||
}
|
||||
|
||||
String featureSetParam = params.get("featureSet");
|
||||
if(featureSetParam != null) {
|
||||
params.remove("featureSet");
|
||||
} else {
|
||||
throw new IOException("featureSet param cannot be null for FeaturesSelectionStream");
|
||||
}
|
||||
|
||||
String positiveLabelParam = params.get("positiveLabel");
|
||||
int positiveLabel = 1;
|
||||
if(positiveLabelParam != null) {
|
||||
params.remove("positiveLabel");
|
||||
positiveLabel = Integer.parseInt(positiveLabelParam);
|
||||
}
|
||||
|
||||
String numTermsParam = params.get("numTerms");
|
||||
int numTerms = 1;
|
||||
if(numTermsParam != null) {
|
||||
numTerms = Integer.parseInt(numTermsParam);
|
||||
params.remove("numTerms");
|
||||
} else {
|
||||
throw new IOException("numTerms param cannot be null for FeaturesSelectionStream");
|
||||
}
|
||||
|
||||
// zkHost, optional - if not provided then will look into factory list to get
|
||||
String zkHost = null;
|
||||
if(null == zkHostExpression){
|
||||
zkHost = factory.getCollectionZkHost(collectionName);
|
||||
}
|
||||
else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
|
||||
zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
|
||||
}
|
||||
if(null == zkHost){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
|
||||
}
|
||||
|
||||
// We've got all the required items
|
||||
init(collectionName, zkHost, params, fieldParam, outcomeParam, featureSetParam, positiveLabel, numTerms);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
|
||||
// functionName(collectionName, param1, param2, ..., paramN, sort="comp", [aliases="field=alias,..."])
|
||||
|
||||
// function name
|
||||
StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
|
||||
|
||||
// collection
|
||||
expression.addParameter(collection);
|
||||
|
||||
// parameters
|
||||
for(Map.Entry<String,String> param : params.entrySet()){
|
||||
expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
|
||||
}
|
||||
|
||||
expression.addParameter(new StreamExpressionNamedParameter("field", field));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("featureSet", featureSet));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", String.valueOf(positiveLabel)));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("numTerms", String.valueOf(numTerms)));
|
||||
|
||||
// zkHost
|
||||
expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
|
||||
|
||||
return expression;
|
||||
}
|
||||
|
||||
private void init(String collectionName,
|
||||
String zkHost,
|
||||
Map params,
|
||||
String field,
|
||||
String outcome,
|
||||
String featureSet,
|
||||
int positiveLabel, int numTopTerms) throws IOException {
|
||||
this.zkHost = zkHost;
|
||||
this.collection = collectionName;
|
||||
this.params = params;
|
||||
this.field = field;
|
||||
this.outcome = outcome;
|
||||
this.featureSet = featureSet;
|
||||
this.positiveLabel = positiveLabel;
|
||||
this.numTerms = numTopTerms;
|
||||
}
|
||||
|
||||
public void setStreamContext(StreamContext context) {
|
||||
this.cache = context.getSolrClientCache();
|
||||
this.streamContext = context;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the CloudSolrStream
|
||||
*
|
||||
***/
|
||||
public void open() throws IOException {
|
||||
if (cache == null) {
|
||||
isCloseCache = true;
|
||||
cache = new SolrClientCache();
|
||||
} else {
|
||||
isCloseCache = false;
|
||||
}
|
||||
|
||||
this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
|
||||
this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("FeaturesSelectionStream"));
|
||||
}
|
||||
|
||||
public List<TupleStream> children() {
|
||||
return null;
|
||||
}
|
||||
|
||||
private List<String> getShardUrls() throws IOException {
|
||||
|
||||
try {
|
||||
|
||||
ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
|
||||
ClusterState clusterState = zkStateReader.getClusterState();
|
||||
|
||||
Collection<Slice> slices = clusterState.getActiveSlices(this.collection);
|
||||
Set<String> liveNodes = clusterState.getLiveNodes();
|
||||
|
||||
List<String> baseUrls = new ArrayList<>();
|
||||
|
||||
for(Slice slice : slices) {
|
||||
Collection<Replica> replicas = slice.getReplicas();
|
||||
List<Replica> shuffler = new ArrayList<>();
|
||||
for(Replica replica : replicas) {
|
||||
if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
|
||||
shuffler.add(replica);
|
||||
}
|
||||
}
|
||||
|
||||
Collections.shuffle(shuffler, new Random());
|
||||
Replica rep = shuffler.get(0);
|
||||
ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
|
||||
String url = zkProps.getCoreUrl();
|
||||
baseUrls.add(url);
|
||||
}
|
||||
|
||||
return baseUrls;
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private List<Future<NamedList>> callShards(List<String> baseUrls) throws IOException {
|
||||
|
||||
List<Future<NamedList>> futures = new ArrayList<>();
|
||||
for (String baseUrl : baseUrls) {
|
||||
FeaturesSelectionCall lc = new FeaturesSelectionCall(baseUrl,
|
||||
this.params,
|
||||
this.field,
|
||||
this.outcome);
|
||||
|
||||
Future<NamedList> future = executorService.submit(lc);
|
||||
futures.add(future);
|
||||
}
|
||||
|
||||
return futures;
|
||||
}
|
||||
|
||||
public void close() throws IOException {
|
||||
if (isCloseCache) {
|
||||
cache.close();
|
||||
}
|
||||
|
||||
executorService.shutdown();
|
||||
}
|
||||
|
||||
/** Return the stream sort - ie, the order in which records are returned */
|
||||
public StreamComparator getStreamSort(){
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation toExplanation(StreamFactory factory) throws IOException {
|
||||
return new StreamExplanation(getStreamNodeId().toString())
|
||||
.withFunctionName(factory.getFunctionName(this.getClass()))
|
||||
.withImplementingClass(this.getClass().getName())
|
||||
.withExpressionType(Explanation.ExpressionType.STREAM_DECORATOR)
|
||||
.withExpression(toExpression(factory).toString());
|
||||
}
|
||||
|
||||
public Tuple read() throws IOException {
|
||||
try {
|
||||
if (tupleIterator == null) {
|
||||
Map<String, Double> termScores = new HashMap<>();
|
||||
Map<String, Long> docFreqs = new HashMap<>();
|
||||
|
||||
|
||||
long numDocs = 0;
|
||||
for (Future<NamedList> getTopTermsCall : callShards(getShardUrls())) {
|
||||
NamedList resp = getTopTermsCall.get();
|
||||
|
||||
NamedList<Double> shardTopTerms = (NamedList<Double>)resp.get("featuredTerms");
|
||||
NamedList<Integer> shardDocFreqs = (NamedList<Integer>)resp.get("docFreq");
|
||||
|
||||
numDocs += (Integer)resp.get("numDocs");
|
||||
|
||||
for (int i = 0; i < shardTopTerms.size(); i++) {
|
||||
String term = shardTopTerms.getName(i);
|
||||
double score = shardTopTerms.getVal(i);
|
||||
int docFreq = shardDocFreqs.get(term);
|
||||
double prevScore = termScores.containsKey(term) ? termScores.get(term) : 0;
|
||||
long prevDocFreq = docFreqs.containsKey(term) ? docFreqs.get(term) : 0;
|
||||
termScores.put(term, prevScore + score);
|
||||
docFreqs.put(term, prevDocFreq + docFreq);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
List<Tuple> tuples = new ArrayList<>(numTerms);
|
||||
termScores = sortByValue(termScores);
|
||||
int index = 0;
|
||||
for (Map.Entry<String, Double> termScore : termScores.entrySet()) {
|
||||
if (tuples.size() == numTerms) break;
|
||||
index++;
|
||||
Map map = new HashMap();
|
||||
map.put("id", featureSet + "_" + index);
|
||||
map.put("index_i", index);
|
||||
map.put("term_s", termScore.getKey());
|
||||
map.put("score_f", termScore.getValue());
|
||||
map.put("featureSet_s", featureSet);
|
||||
long docFreq = docFreqs.get(termScore.getKey());
|
||||
double d = Math.log(((double)numDocs / (double)(docFreq + 1)));
|
||||
map.put("idf_d", d);
|
||||
tuples.add(new Tuple(map));
|
||||
}
|
||||
|
||||
Map map = new HashMap();
|
||||
map.put("EOF", true);
|
||||
tuples.add(new Tuple(map));
|
||||
|
||||
tupleIterator = tuples.iterator();
|
||||
}
|
||||
|
||||
return tupleIterator.next();
|
||||
} catch(Exception e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private <K, V extends Comparable<? super V>> Map<K, V> sortByValue( Map<K, V> map )
|
||||
{
|
||||
Map<K, V> result = new LinkedHashMap<>();
|
||||
Stream<Map.Entry<K, V>> st = map.entrySet().stream();
|
||||
|
||||
st.sorted( Map.Entry.comparingByValue(
|
||||
(c1, c2) -> c2.compareTo(c1)
|
||||
) ).forEachOrdered( e -> result.put(e.getKey(), e.getValue()) );
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
protected class FeaturesSelectionCall implements Callable<NamedList> {
|
||||
|
||||
private String baseUrl;
|
||||
private String outcome;
|
||||
private String field;
|
||||
private Map<String, String> paramsMap;
|
||||
|
||||
public FeaturesSelectionCall(String baseUrl,
|
||||
Map<String, String> paramsMap,
|
||||
String field,
|
||||
String outcome) {
|
||||
|
||||
this.baseUrl = baseUrl;
|
||||
this.outcome = outcome;
|
||||
this.field = field;
|
||||
this.paramsMap = paramsMap;
|
||||
}
|
||||
|
||||
public NamedList<Double> call() throws Exception {
|
||||
ModifiableSolrParams params = new ModifiableSolrParams();
|
||||
HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);
|
||||
|
||||
params.add("distrib", "false");
|
||||
params.add("fq","{!igain}");
|
||||
|
||||
for(String key : paramsMap.keySet()) {
|
||||
params.add(key, paramsMap.get(key));
|
||||
}
|
||||
|
||||
params.add("outcome", outcome);
|
||||
params.add("positiveLabel", Integer.toString(positiveLabel));
|
||||
params.add("field", field);
|
||||
params.add("numTerms", String.valueOf(numTerms));
|
||||
|
||||
QueryRequest request= new QueryRequest(params);
|
||||
QueryResponse response = request.process(solrClient);
|
||||
NamedList res = response.getResponse();
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,657 @@
|
|||
package org.apache.solr.client.solrj.io.stream;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
import org.apache.solr.client.solrj.SolrRequest;
|
||||
import org.apache.solr.client.solrj.SolrServerException;
|
||||
import org.apache.solr.client.solrj.impl.CloudSolrClient;
|
||||
import org.apache.solr.client.solrj.impl.HttpSolrClient;
|
||||
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
|
||||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
import org.apache.solr.client.solrj.request.QueryRequest;
|
||||
import org.apache.solr.client.solrj.response.QueryResponse;
|
||||
import org.apache.solr.common.cloud.ClusterState;
|
||||
import org.apache.solr.common.cloud.Replica;
|
||||
import org.apache.solr.common.cloud.Slice;
|
||||
import org.apache.solr.common.cloud.ZkCoreNodeProps;
|
||||
import org.apache.solr.common.cloud.ZkStateReader;
|
||||
import org.apache.solr.common.params.ModifiableSolrParams;
|
||||
import org.apache.solr.common.util.ExecutorUtil;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.common.util.SolrjNamedThreadFactory;
|
||||
|
||||
public class TextLogitStream extends TupleStream implements Expressible {
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
protected String zkHost;
|
||||
protected String collection;
|
||||
protected Map<String,String> params;
|
||||
protected String field;
|
||||
protected String name;
|
||||
protected String outcome;
|
||||
protected int positiveLabel;
|
||||
protected double threshold;
|
||||
protected List<Double> weights;
|
||||
protected int maxIterations;
|
||||
protected int iteration;
|
||||
protected double error;
|
||||
protected List<Double> idfs;
|
||||
protected ClassificationEvaluation evaluation;
|
||||
|
||||
protected transient SolrClientCache cache;
|
||||
protected transient boolean isCloseCache;
|
||||
protected transient CloudSolrClient cloudSolrClient;
|
||||
|
||||
protected transient StreamContext streamContext;
|
||||
protected ExecutorService executorService;
|
||||
protected TupleStream termsStream;
|
||||
private List<String> terms;
|
||||
|
||||
private double learningRate = 0.01;
|
||||
private double lastError = 0;
|
||||
|
||||
public TextLogitStream(String zkHost,
|
||||
String collectionName,
|
||||
Map params,
|
||||
String name,
|
||||
String field,
|
||||
TupleStream termsStream,
|
||||
List<Double> weights,
|
||||
String outcome,
|
||||
int positiveLabel,
|
||||
double threshold,
|
||||
int maxIterations) throws IOException {
|
||||
|
||||
init(collectionName, zkHost, params, name, field, termsStream, weights, outcome, positiveLabel, threshold, maxIterations, iteration);
|
||||
}
|
||||
|
||||
/**
|
||||
* logit(collection, zkHost="", features="a,b,c,d,e,f,g", outcome="y", maxIteration="20")
|
||||
**/
|
||||
|
||||
public TextLogitStream(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
// grab all parameters out
|
||||
String collectionName = factory.getValueOperand(expression, 0);
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
|
||||
List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
|
||||
|
||||
// Validate there are no unknown parameters - zkHost and alias are namedParameter so we don't need to count it twice
|
||||
if(expression.getParameters().size() != 1 + namedParams.size() + streamExpressions.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
|
||||
}
|
||||
|
||||
// Collection Name
|
||||
if(null == collectionName){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
|
||||
}
|
||||
|
||||
// Named parameters - passed directly to solr as solrparams
|
||||
if(0 == namedParams.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
|
||||
}
|
||||
|
||||
Map<String,String> params = new HashMap<String,String>();
|
||||
for(StreamExpressionNamedParameter namedParam : namedParams){
|
||||
if(!namedParam.getName().equals("zkHost")) {
|
||||
params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
|
||||
}
|
||||
}
|
||||
|
||||
String name = params.get("name");
|
||||
if (name != null) {
|
||||
params.remove("name");
|
||||
} else {
|
||||
throw new IOException("name param cannot be null for TextLogitStream");
|
||||
}
|
||||
|
||||
String feature = params.get("field");
|
||||
if (feature != null) {
|
||||
params.remove("field");
|
||||
} else {
|
||||
throw new IOException("field param cannot be null for TextLogitStream");
|
||||
}
|
||||
|
||||
TupleStream stream = null;
|
||||
|
||||
if (streamExpressions.size() > 0) {
|
||||
stream = factory.constructStream(streamExpressions.get(0));
|
||||
} else {
|
||||
throw new IOException("features must be present for TextLogitStream");
|
||||
}
|
||||
|
||||
String maxIterationsParam = params.get("maxIterations");
|
||||
int maxIterations = 0;
|
||||
if(maxIterationsParam != null) {
|
||||
maxIterations = Integer.parseInt(maxIterationsParam);
|
||||
params.remove("maxIterations");
|
||||
} else {
|
||||
throw new IOException("maxIterations param cannot be null for TextLogitStream");
|
||||
}
|
||||
|
||||
String outcomeParam = params.get("outcome");
|
||||
|
||||
if(outcomeParam != null) {
|
||||
params.remove("outcome");
|
||||
} else {
|
||||
throw new IOException("outcome param cannot be null for TextLogitStream");
|
||||
}
|
||||
|
||||
String positiveLabelParam = params.get("positiveLabel");
|
||||
int positiveLabel = 1;
|
||||
if(positiveLabelParam != null) {
|
||||
positiveLabel = Integer.parseInt(positiveLabelParam);
|
||||
params.remove("positiveLabel");
|
||||
}
|
||||
|
||||
String thresholdParam = params.get("threshold");
|
||||
double threshold = 0.5;
|
||||
if(thresholdParam != null) {
|
||||
threshold = Double.parseDouble(thresholdParam);
|
||||
params.remove("threshold");
|
||||
}
|
||||
|
||||
int iteration = 0;
|
||||
String iterationParam = params.get("iteration");
|
||||
if(iterationParam != null) {
|
||||
iteration = Integer.parseInt(iterationParam);
|
||||
params.remove("iteration");
|
||||
}
|
||||
|
||||
List<Double> weights = null;
|
||||
String weightsParam = params.get("weights");
|
||||
if(weightsParam != null) {
|
||||
weights = new ArrayList<>();
|
||||
String[] weightsArray = weightsParam.split(",");
|
||||
for(String weightString : weightsArray) {
|
||||
weights.add(Double.parseDouble(weightString));
|
||||
}
|
||||
params.remove("weights");
|
||||
}
|
||||
|
||||
// zkHost, optional - if not provided then will look into factory list to get
|
||||
String zkHost = null;
|
||||
if(null == zkHostExpression){
|
||||
zkHost = factory.getCollectionZkHost(collectionName);
|
||||
}
|
||||
else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
|
||||
zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
|
||||
}
|
||||
if(null == zkHost){
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
|
||||
}
|
||||
|
||||
// We've got all the required items
|
||||
init(collectionName, zkHost, params, name, feature, stream, weights, outcomeParam, positiveLabel, threshold, maxIterations, iteration);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
|
||||
return toExpression(factory, true);
|
||||
}
|
||||
|
||||
private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
|
||||
// function name
|
||||
StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
|
||||
|
||||
// collection
|
||||
expression.addParameter(collection);
|
||||
|
||||
if (includeStreams && !(termsStream instanceof TermsStream)) {
|
||||
if (termsStream instanceof Expressible) {
|
||||
expression.addParameter(((Expressible)termsStream).toExpression(factory));
|
||||
} else {
|
||||
throw new IOException("This TextLogitStream contains a non-expressible TupleStream - it cannot be converted to an expression");
|
||||
}
|
||||
}
|
||||
|
||||
// parameters
|
||||
for(Entry<String,String> param : params.entrySet()){
|
||||
expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
|
||||
}
|
||||
|
||||
expression.addParameter(new StreamExpressionNamedParameter("field", field));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("name", name));
|
||||
if (termsStream instanceof TermsStream) {
|
||||
loadTerms();
|
||||
expression.addParameter(new StreamExpressionNamedParameter("terms", toString(terms)));
|
||||
}
|
||||
|
||||
expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
|
||||
if(weights != null) {
|
||||
expression.addParameter(new StreamExpressionNamedParameter("weights", toString(weights)));
|
||||
}
|
||||
expression.addParameter(new StreamExpressionNamedParameter("maxIterations", Integer.toString(maxIterations)));
|
||||
|
||||
if(iteration > 0) {
|
||||
expression.addParameter(new StreamExpressionNamedParameter("iteration", Integer.toString(iteration)));
|
||||
}
|
||||
|
||||
expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", Integer.toString(positiveLabel)));
|
||||
expression.addParameter(new StreamExpressionNamedParameter("threshold", Double.toString(threshold)));
|
||||
|
||||
// zkHost
|
||||
expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
|
||||
|
||||
return expression;
|
||||
}
|
||||
|
||||
private void init(String collectionName,
|
||||
String zkHost,
|
||||
Map params,
|
||||
String name,
|
||||
String feature,
|
||||
TupleStream termsStream,
|
||||
List<Double> weights,
|
||||
String outcome,
|
||||
int positiveLabel,
|
||||
double threshold,
|
||||
int maxIterations,
|
||||
int iteration) throws IOException {
|
||||
this.zkHost = zkHost;
|
||||
this.collection = collectionName;
|
||||
this.params = params;
|
||||
this.name = name;
|
||||
this.field = feature;
|
||||
this.termsStream = termsStream;
|
||||
this.outcome = outcome;
|
||||
this.positiveLabel = positiveLabel;
|
||||
this.threshold = threshold;
|
||||
this.weights = weights;
|
||||
this.maxIterations = maxIterations;
|
||||
this.iteration = iteration;
|
||||
}
|
||||
|
||||
public void setStreamContext(StreamContext context) {
|
||||
this.cache = context.getSolrClientCache();
|
||||
this.streamContext = context;
|
||||
this.termsStream.setStreamContext(context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the CloudSolrStream
|
||||
*
|
||||
***/
|
||||
public void open() throws IOException {
|
||||
if (cache == null) {
|
||||
isCloseCache = true;
|
||||
cache = new SolrClientCache();
|
||||
} else {
|
||||
isCloseCache = false;
|
||||
}
|
||||
|
||||
this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
|
||||
this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("TextLogitSolrStream"));
|
||||
}
|
||||
|
||||
public List<TupleStream> children() {
|
||||
List<TupleStream> l = new ArrayList();
|
||||
l.add(termsStream);
|
||||
return l;
|
||||
}
|
||||
|
||||
protected List<String> getShardUrls() throws IOException {
|
||||
|
||||
try {
|
||||
|
||||
ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
|
||||
ClusterState clusterState = zkStateReader.getClusterState();
|
||||
Set<String> liveNodes = clusterState.getLiveNodes();
|
||||
|
||||
Collection<Slice> slices = clusterState.getActiveSlices(this.collection);
|
||||
List baseUrls = new ArrayList();
|
||||
|
||||
for(Slice slice : slices) {
|
||||
Collection<Replica> replicas = slice.getReplicas();
|
||||
List<Replica> shuffler = new ArrayList();
|
||||
for(Replica replica : replicas) {
|
||||
if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
|
||||
shuffler.add(replica);
|
||||
}
|
||||
}
|
||||
|
||||
Collections.shuffle(shuffler, new Random());
|
||||
Replica rep = shuffler.get(0);
|
||||
ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
|
||||
String url = zkProps.getCoreUrl();
|
||||
baseUrls.add(url);
|
||||
}
|
||||
|
||||
return baseUrls;
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private List<Future<Tuple>> callShards(List<String> baseUrls) throws IOException {
|
||||
|
||||
List<Future<Tuple>> futures = new ArrayList();
|
||||
for (String baseUrl : baseUrls) {
|
||||
LogitCall lc = new LogitCall(baseUrl,
|
||||
this.params,
|
||||
this.field,
|
||||
this.terms,
|
||||
this.weights,
|
||||
this.outcome,
|
||||
this.positiveLabel,
|
||||
this.learningRate,
|
||||
this.iteration);
|
||||
|
||||
Future<Tuple> future = executorService.submit(lc);
|
||||
futures.add(future);
|
||||
}
|
||||
|
||||
return futures;
|
||||
}
|
||||
|
||||
public void close() throws IOException {
|
||||
if (isCloseCache) {
|
||||
cache.close();
|
||||
}
|
||||
|
||||
executorService.shutdown();
|
||||
termsStream.close();
|
||||
}
|
||||
|
||||
/** Return the stream sort - ie, the order in which records are returned */
|
||||
public StreamComparator getStreamSort(){
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation toExplanation(StreamFactory factory) throws IOException {
|
||||
StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
|
||||
explanation.setFunctionName(factory.getFunctionName(this.getClass()));
|
||||
explanation.setImplementingClass(this.getClass().getName());
|
||||
explanation.setExpressionType(Explanation.ExpressionType.MACHINE_LEARNING_MODEL);
|
||||
explanation.setExpression(toExpression(factory).toString());
|
||||
|
||||
explanation.addChild(termsStream.toExplanation(factory));
|
||||
|
||||
return explanation;
|
||||
}
|
||||
|
||||
public void loadTerms() throws IOException {
|
||||
if (this.terms == null) {
|
||||
termsStream.open();
|
||||
this.terms = new ArrayList<>();
|
||||
this.idfs = new ArrayList();
|
||||
|
||||
while (true) {
|
||||
Tuple termTuple = termsStream.read();
|
||||
if (termTuple.EOF) {
|
||||
break;
|
||||
} else {
|
||||
terms.add(termTuple.getString("term_s"));
|
||||
idfs.add(termTuple.getDouble("idf_d"));
|
||||
}
|
||||
}
|
||||
termsStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
public Tuple read() throws IOException {
|
||||
try {
|
||||
|
||||
if(++iteration > maxIterations) {
|
||||
Map map = new HashMap();
|
||||
map.put("EOF", true);
|
||||
return new Tuple(map);
|
||||
} else {
|
||||
|
||||
if (this.idfs == null) {
|
||||
loadTerms();
|
||||
|
||||
if (weights != null && terms.size() + 1 != weights.size()) {
|
||||
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - the number of weights must be %d, found %d", terms.size()+1, weights.size()));
|
||||
}
|
||||
}
|
||||
|
||||
List<List<Double>> allWeights = new ArrayList();
|
||||
this.evaluation = new ClassificationEvaluation();
|
||||
|
||||
this.error = 0;
|
||||
for (Future<Tuple> logitCall : callShards(getShardUrls())) {
|
||||
|
||||
Tuple tuple = logitCall.get();
|
||||
List<Double> shardWeights = (List<Double>) tuple.get("weights");
|
||||
allWeights.add(shardWeights);
|
||||
this.error += tuple.getDouble("error");
|
||||
Map shardEvaluation = (Map) tuple.get("evaluation");
|
||||
this.evaluation.addEvaluation(shardEvaluation);
|
||||
}
|
||||
|
||||
this.weights = averageWeights(allWeights);
|
||||
Map map = new HashMap();
|
||||
map.put("id", name+"_"+iteration);
|
||||
map.put("name_s", name);
|
||||
map.put("field_s", field);
|
||||
map.put("terms_ss", terms);
|
||||
map.put("iteration_i", iteration);
|
||||
|
||||
if(weights != null) {
|
||||
map.put("weights_ds", weights);
|
||||
}
|
||||
|
||||
map.put("error_d", error);
|
||||
evaluation.putToMap(map);
|
||||
map.put("alpha_d", this.learningRate);
|
||||
map.put("idfs_ds", this.idfs);
|
||||
|
||||
if (iteration != 1) {
|
||||
if (lastError <= error) {
|
||||
this.learningRate *= 0.5;
|
||||
} else {
|
||||
this.learningRate *= 1.05;
|
||||
}
|
||||
}
|
||||
|
||||
lastError = error;
|
||||
|
||||
return new Tuple(map);
|
||||
}
|
||||
|
||||
} catch(Exception e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private List<Double> averageWeights(List<List<Double>> allWeights) {
|
||||
double[] working = new double[allWeights.get(0).size()];
|
||||
for(List<Double> shardWeights: allWeights) {
|
||||
for(int i=0; i<working.length; i++) {
|
||||
working[i] += shardWeights.get(i);
|
||||
}
|
||||
}
|
||||
|
||||
for(int i=0; i<working.length; i++) {
|
||||
working[i] = working[i] / allWeights.size();
|
||||
}
|
||||
|
||||
List<Double> ave = new ArrayList();
|
||||
for(double d : working) {
|
||||
ave.add(d);
|
||||
}
|
||||
|
||||
return ave;
|
||||
}
|
||||
|
||||
static String toString(List items) {
|
||||
StringBuilder buf = new StringBuilder();
|
||||
for(Object item : items) {
|
||||
if(buf.length() > 0) {
|
||||
buf.append(",");
|
||||
}
|
||||
|
||||
buf.append(item.toString());
|
||||
}
|
||||
|
||||
return buf.toString();
|
||||
}
|
||||
|
||||
protected class TermsStream extends TupleStream {
|
||||
|
||||
private List<String> terms;
|
||||
private Iterator<String> it;
|
||||
|
||||
public TermsStream(List<String> terms) {
|
||||
this.terms = terms;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStreamContext(StreamContext context) {}
|
||||
|
||||
@Override
|
||||
public List<TupleStream> children() { return new ArrayList<>(); }
|
||||
|
||||
@Override
|
||||
public void open() throws IOException { this.it = this.terms.iterator();}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {}
|
||||
|
||||
@Override
|
||||
public Tuple read() throws IOException {
|
||||
HashMap map = new HashMap();
|
||||
if(it.hasNext()) {
|
||||
map.put("term_s",it.next());
|
||||
map.put("score_f",1.0);
|
||||
return new Tuple(map);
|
||||
} else {
|
||||
map.put("EOF", true);
|
||||
return new Tuple(map);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamComparator getStreamSort() {return null;}
|
||||
|
||||
@Override
|
||||
public Explanation toExplanation(StreamFactory factory) throws IOException {
|
||||
return new StreamExplanation(getStreamNodeId().toString())
|
||||
.withFunctionName("non-expressible")
|
||||
.withImplementingClass(this.getClass().getName())
|
||||
.withExpressionType(Explanation.ExpressionType.STREAM_SOURCE)
|
||||
.withExpression("non-expressible");
|
||||
}
|
||||
}
|
||||
|
||||
protected class LogitCall implements Callable<Tuple> {
|
||||
|
||||
private String baseUrl;
|
||||
private String feature;
|
||||
private List<String> terms;
|
||||
private List<Double> weights;
|
||||
private int iteration;
|
||||
private String outcome;
|
||||
private int positiveLabel;
|
||||
private double learningRate;
|
||||
private Map<String, String> paramsMap;
|
||||
|
||||
public LogitCall(String baseUrl,
|
||||
Map<String, String> paramsMap,
|
||||
String feature,
|
||||
List<String> terms,
|
||||
List<Double> weights,
|
||||
String outcome,
|
||||
int positiveLabel,
|
||||
double learningRate,
|
||||
int iteration) {
|
||||
|
||||
this.baseUrl = baseUrl;
|
||||
this.feature = feature;
|
||||
this.terms = terms;
|
||||
this.weights = weights;
|
||||
this.iteration = iteration;
|
||||
this.outcome = outcome;
|
||||
this.positiveLabel = positiveLabel;
|
||||
this.learningRate = learningRate;
|
||||
this.paramsMap = paramsMap;
|
||||
}
|
||||
|
||||
public Tuple call() throws Exception {
|
||||
ModifiableSolrParams params = new ModifiableSolrParams();
|
||||
HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);
|
||||
|
||||
params.add("distrib", "false");
|
||||
params.add("fq","{!tlogit}");
|
||||
params.add("feature", feature);
|
||||
params.add("terms", TextLogitStream.toString(terms));
|
||||
params.add("idfs", TextLogitStream.toString(idfs));
|
||||
|
||||
for(String key : paramsMap.keySet()) {
|
||||
params.add(key, paramsMap.get(key));
|
||||
}
|
||||
|
||||
if(weights != null) {
|
||||
params.add("weights", TextLogitStream.toString(weights));
|
||||
}
|
||||
|
||||
params.add("iteration", Integer.toString(iteration));
|
||||
params.add("outcome", outcome);
|
||||
params.add("positiveLabel", Integer.toString(positiveLabel));
|
||||
params.add("threshold", Double.toString(threshold));
|
||||
params.add("alpha", Double.toString(learningRate));
|
||||
|
||||
QueryRequest request= new QueryRequest(params, SolrRequest.METHOD.POST);
|
||||
QueryResponse response = request.process(solrClient);
|
||||
NamedList res = response.getResponse();
|
||||
|
||||
NamedList logit = (NamedList)res.get("logit");
|
||||
|
||||
List<Double> shardWeights = (List<Double>)logit.get("weights");
|
||||
double shardError = (double)logit.get("error");
|
||||
|
||||
Map map = new HashMap();
|
||||
|
||||
map.put("error", shardError);
|
||||
map.put("weights", shardWeights);
|
||||
map.put("evaluation", logit.get("evaluation"));
|
||||
|
||||
return new Tuple(map);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -147,6 +147,7 @@ public class Explanation {
|
|||
|
||||
public static interface ExpressionType{
|
||||
public static final String GRAPH_SOURCE = "graph-source";
|
||||
public static final String MACHINE_LEARNING_MODEL = "ml-model";
|
||||
public static final String STREAM_SOURCE = "stream-source";
|
||||
public static final String STREAM_DECORATOR = "stream-decorator";
|
||||
public static final String DATASTORE = "datastore";
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
<?xml version="1.0" ?>
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
(the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
<!-- The Solr schema file. This file should be named "schema.xml" and
|
||||
should be located where the classloader for the Solr webapp can find it.
|
||||
|
||||
This schema is used for testing, and as such has everything and the
|
||||
kitchen sink thrown in. See example/solr/conf/schema.xml for a
|
||||
more concise example.
|
||||
|
||||
-->
|
||||
|
||||
<schema name="test" version="1.6">
|
||||
|
||||
<fieldType name="int" docValues="true" class="solr.TrieIntField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
|
||||
<fieldType name="float" docValues="true" class="solr.TrieFloatField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
|
||||
<fieldType name="long" class="solr.TrieLongField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
|
||||
<fieldType name="double" class="solr.TrieDoubleField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
|
||||
|
||||
<fieldType name="tint" class="solr.TrieIntField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
|
||||
<fieldType name="tfloat" class="solr.TrieFloatField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
|
||||
<fieldType name="tlong" class="solr.TrieLongField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
|
||||
<fieldType name="tdouble" class="solr.TrieDoubleField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
|
||||
|
||||
<fieldType name="random" class="solr.RandomSortField" indexed="true" />
|
||||
|
||||
<fieldtype name="boolean" class="solr.BoolField" sortMissingLast="true"/>
|
||||
<fieldtype name="string" class="solr.StrField" sortMissingLast="true" docValues="true"/>
|
||||
|
||||
<!-- format for date is 1995-12-31T23:59:59.999Z and only the fractional
|
||||
seconds part (.999) is optional.
|
||||
-->
|
||||
<fieldtype name="date" class="solr.TrieDateField" precisionStep="0"/>
|
||||
<fieldtype name="tdate" class="solr.TrieDateField" precisionStep="6"/>
|
||||
|
||||
|
||||
<field name="id" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
|
||||
<field name="_version_" type="long" indexed="true" stored="true"/>
|
||||
|
||||
<!-- Dynamic field definitions. If a field name is not found, dynamicFields
|
||||
will be used if the name matches any of the patterns.
|
||||
RESTRICTION: the glob-like pattern in the name attribute must have
|
||||
a "*" only at the start or the end.
|
||||
EXAMPLE: name="*_i" will match any field ending in _i (like myid_i, z_i)
|
||||
Longer patterns will be matched first. if equal size patterns
|
||||
both match, the first appearing in the schema will be used.
|
||||
-->
|
||||
<dynamicField name="*_b" type="boolean" indexed="true" stored="true" multiValued="false"/>
|
||||
<dynamicField name="*_bs" type="boolean" indexed="true" stored="true" multiValued="true"/>
|
||||
<dynamicField name="*_i" type="int" indexed="true" stored="true" multiValued="false"/>
|
||||
<dynamicField name="*_is" type="int" indexed="true" stored="true" multiValued="true"/>
|
||||
<dynamicField name="*_l" type="long" indexed="true" stored="true" multiValued="false"/>
|
||||
<dynamicField name="*_ls" type="long" indexed="true" stored="true" multiValued="true"/>
|
||||
<dynamicField name="*_f" type="float" indexed="true" stored="true" multiValued="false"/>
|
||||
<dynamicField name="*_fs" type="float" indexed="true" stored="true" multiValued="true"/>
|
||||
<dynamicField name="*_d" type="double" indexed="true" stored="true" multiValued="false"/>
|
||||
<dynamicField name="*_ds" type="double" indexed="true" stored="true" multiValued="true"/>
|
||||
|
||||
<dynamicField name="*_s" type="string" indexed="true" stored="true" multiValued="false"/>
|
||||
<dynamicField name="*_ss" type="string" indexed="true" stored="true" multiValued="true"/>
|
||||
<uniqueKey>id</uniqueKey>
|
||||
</schema>
|
|
@ -0,0 +1,51 @@
|
|||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
contributor license agreements. See the NOTICE file distributed with
|
||||
this work for additional information regarding copyright ownership.
|
||||
The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
(the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
<!--
|
||||
This is a stripped down config file used for a simple example...
|
||||
It is *not* a good example to work from.
|
||||
-->
|
||||
<config>
|
||||
<luceneMatchVersion>${tests.luceneMatchVersion:LUCENE_CURRENT}</luceneMatchVersion>
|
||||
<indexConfig>
|
||||
<useCompoundFile>${useCompoundFile:false}</useCompoundFile>
|
||||
</indexConfig>
|
||||
<dataDir>${solr.data.dir:}</dataDir>
|
||||
<directoryFactory name="DirectoryFactory" class="${solr.directoryFactory:solr.StandardDirectoryFactory}"/>
|
||||
<schemaFactory class="ClassicIndexSchemaFactory"/>
|
||||
|
||||
<updateHandler class="solr.DirectUpdateHandler2">
|
||||
<updateLog>
|
||||
<str name="dir">${solr.data.dir:}</str>
|
||||
</updateLog>
|
||||
</updateHandler>
|
||||
|
||||
|
||||
<requestDispatcher handleSelect="true" >
|
||||
<requestParsers enableRemoteStreaming="false" multipartUploadLimitInKB="2048" />
|
||||
</requestDispatcher>
|
||||
|
||||
<requestHandler name="standard" class="solr.StandardRequestHandler" default="true" />
|
||||
|
||||
<!-- config for the admin interface -->
|
||||
<admin>
|
||||
<defaultQuery>solr</defaultQuery>
|
||||
</admin>
|
||||
|
||||
</config>
|
||||
|
|
@ -16,16 +16,23 @@
|
|||
*/
|
||||
package org.apache.solr.client.solrj.io.stream;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Enumeration;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.LuceneTestCase.Slow;
|
||||
import org.apache.solr.client.solrj.embedded.JettySolrRunner;
|
||||
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
|
||||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
|
||||
|
@ -71,6 +78,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
public static void setupCluster() throws Exception {
|
||||
configureCluster(4)
|
||||
.addConfig("conf", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("streaming").resolve("conf"))
|
||||
.addConfig("ml", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("ml").resolve("conf"))
|
||||
.configure();
|
||||
|
||||
CollectionAdminRequest.createCollection(COLLECTION, "conf", 2, 1).process(cluster.getSolrClient());
|
||||
|
@ -2773,6 +2781,8 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assert(tuple.getDouble("a_f") == 4.0);
|
||||
assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
|
||||
assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
|
||||
|
||||
CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -2863,6 +2873,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
|
||||
assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
|
||||
|
||||
CollectionAdminRequest.deleteCollection("parallelDestinationCollection").process(cluster.getSolrClient());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -3025,6 +3036,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
|
||||
assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
|
||||
|
||||
CollectionAdminRequest.deleteCollection("parallelDestinationCollection1").process(cluster.getSolrClient());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -3065,6 +3077,121 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testBasicTextLogitStream() throws Exception {
|
||||
CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
|
||||
AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(),
|
||||
false, true, TIMEOUT);
|
||||
|
||||
UpdateRequest updateRequest = new UpdateRequest();
|
||||
for (int i = 0; i < 5000; i+=2) {
|
||||
updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "1");
|
||||
updateRequest.add(id, String.valueOf(i+1), "tv_text", "a b e e f", "out_i", "0");
|
||||
}
|
||||
updateRequest.commit(cluster.getSolrClient(), COLLECTION);
|
||||
|
||||
StreamExpression expression;
|
||||
TupleStream stream;
|
||||
List<Tuple> tuples;
|
||||
|
||||
StreamFactory factory = new StreamFactory()
|
||||
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
|
||||
.withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress())
|
||||
.withFunctionName("features", FeaturesSelectionStream.class)
|
||||
.withFunctionName("train", TextLogitStream.class)
|
||||
.withFunctionName("search", CloudSolrStream.class)
|
||||
.withFunctionName("update", UpdateStream.class);
|
||||
|
||||
expression = StreamExpressionParser.parse("features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4)");
|
||||
stream = new FeaturesSelectionStream(expression, factory);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 4);
|
||||
HashSet<String> terms = new HashSet<>();
|
||||
for (Tuple tuple : tuples) {
|
||||
terms.add((String) tuple.get("term_s"));
|
||||
}
|
||||
assertTrue(terms.contains("d"));
|
||||
assertTrue(terms.contains("c"));
|
||||
assertTrue(terms.contains("e"));
|
||||
assertTrue(terms.contains("f"));
|
||||
|
||||
String textLogitExpression = "train(" +
|
||||
"collection1, " +
|
||||
"features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4),"+
|
||||
"q=\"*:*\", " +
|
||||
"name=\"model\", " +
|
||||
"field=\"tv_text\", " +
|
||||
"outcome=\"out_i\", " +
|
||||
"maxIterations=100)";
|
||||
stream = factory.constructStream(textLogitExpression);
|
||||
tuples = getTuples(stream);
|
||||
Tuple lastTuple = tuples.get(tuples.size() - 1);
|
||||
List<Double> lastWeights = lastTuple.getDoubles("weights_ds");
|
||||
Double[] lastWeightsArray = lastWeights.toArray(new Double[lastWeights.size()]);
|
||||
|
||||
// first feature is bias value
|
||||
Double[] testRecord = {1.0, 1.17, 0.691, 0.0, 0.0};
|
||||
double d = sum(multiply(testRecord, lastWeightsArray));
|
||||
double prob = sigmoid(d);
|
||||
assertEquals(prob, 1.0, 0.1);
|
||||
|
||||
// first feature is bias value
|
||||
Double[] testRecord2 = {1.0, 0.0, 0.0, 1.17, 0.691};
|
||||
d = sum(multiply(testRecord2, lastWeightsArray));
|
||||
prob = sigmoid(d);
|
||||
assertEquals(prob, 0, 0.1);
|
||||
|
||||
stream = factory.constructStream("update(destinationCollection, batchSize=5, "+textLogitExpression+")");
|
||||
getTuples(stream);
|
||||
cluster.getSolrClient().commit("destinationCollection");
|
||||
|
||||
stream = factory.constructStream("search(destinationCollection, " +
|
||||
"q=*:*, " +
|
||||
"fl=\"iteration_i,* \", " +
|
||||
"rows=100, " +
|
||||
"sort=\"iteration_i desc\")");
|
||||
tuples = getTuples(stream);
|
||||
assertEquals(100, tuples.size());
|
||||
Tuple lastModel = tuples.get(0);
|
||||
ClassificationEvaluation evaluation = ClassificationEvaluation.create(lastModel.fields);
|
||||
assertTrue(evaluation.getF1() >= 1.0);
|
||||
assertEquals(Math.log( 5000.0 / (2500 + 1)), lastModel.getDoubles("idfs_ds").get(0), 0.0001);
|
||||
// make sure the tuples is retrieved in correct order
|
||||
Tuple firstTuple = tuples.get(99);
|
||||
assertEquals(1L, (long) firstTuple.getLong("iteration_i"));
|
||||
|
||||
CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
|
||||
}
|
||||
|
||||
private double sigmoid(double in) {
|
||||
|
||||
double d = 1.0 / (1+Math.exp(-in));
|
||||
return d;
|
||||
}
|
||||
|
||||
private double[] multiply(Double[] vec1, Double[] vec2) {
|
||||
double[] working = new double[vec1.length];
|
||||
for(int i=0; i<vec1.length; i++) {
|
||||
working[i]= vec1[i]*vec2[i];
|
||||
}
|
||||
|
||||
return working;
|
||||
}
|
||||
|
||||
|
||||
private double sum(double[] vec) {
|
||||
double d = 0.0;
|
||||
|
||||
for(double v : vec) {
|
||||
d += v;
|
||||
}
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testParallelIntersectStream() throws Exception {
|
||||
|
||||
|
@ -3103,6 +3230,62 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFeaturesSelectionStream() throws Exception {
|
||||
|
||||
CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
|
||||
AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(),
|
||||
false, true, TIMEOUT);
|
||||
|
||||
UpdateRequest updateRequest = new UpdateRequest();
|
||||
for (int i = 0; i < 5000; i+=2) {
|
||||
updateRequest.add(id, String.valueOf(i), "whitetok", "a b c d", "out_i", "1");
|
||||
updateRequest.add(id, String.valueOf(i+1), "whitetok", "a b e f", "out_i", "0");
|
||||
}
|
||||
updateRequest.commit(cluster.getSolrClient(), COLLECTION);
|
||||
|
||||
StreamExpression expression;
|
||||
TupleStream stream;
|
||||
List<Tuple> tuples;
|
||||
|
||||
StreamFactory factory = new StreamFactory()
|
||||
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
|
||||
.withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress())
|
||||
.withFunctionName("featuresSelection", FeaturesSelectionStream.class)
|
||||
.withFunctionName("search", CloudSolrStream.class)
|
||||
.withFunctionName("update", UpdateStream.class);
|
||||
|
||||
String featuresExpression = "featuresSelection(collection1, q=\"*:*\", featureSet=\"first\", field=\"whitetok\", outcome=\"out_i\", numTerms=4)";
|
||||
// basic
|
||||
expression = StreamExpressionParser.parse(featuresExpression);
|
||||
stream = new FeaturesSelectionStream(expression, factory);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 4);
|
||||
|
||||
assertTrue(tuples.get(0).get("term_s").equals("c"));
|
||||
assertTrue(tuples.get(1).get("term_s").equals("d"));
|
||||
assertTrue(tuples.get(2).get("term_s").equals("e"));
|
||||
assertTrue(tuples.get(3).get("term_s").equals("f"));
|
||||
|
||||
// update
|
||||
expression = StreamExpressionParser.parse("update(destinationCollection, batchSize=5, "+featuresExpression+")");
|
||||
stream = new UpdateStream(expression, factory);
|
||||
getTuples(stream);
|
||||
cluster.getSolrClient().commit("destinationCollection");
|
||||
|
||||
expression = StreamExpressionParser.parse("search(destinationCollection, q=featureSet_s:first, fl=\"index_i, term_s\", sort=\"index_i asc\")");
|
||||
stream = new CloudSolrStream(expression, factory);
|
||||
tuples = getTuples(stream);
|
||||
assertEquals(4, tuples.size());
|
||||
assertTrue(tuples.get(0).get("term_s").equals("c"));
|
||||
assertTrue(tuples.get(1).get("term_s").equals("d"));
|
||||
assertTrue(tuples.get(2).get("term_s").equals("e"));
|
||||
assertTrue(tuples.get(3).get("term_s").equals("f"));
|
||||
|
||||
CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComplementStream() throws Exception {
|
||||
|
||||
|
|
|
@ -62,6 +62,8 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
|
|||
.withFunctionName("avg", MeanMetric.class)
|
||||
.withFunctionName("daemon", DaemonStream.class)
|
||||
.withFunctionName("topic", TopicStream.class)
|
||||
.withFunctionName("tlogit", TextLogitStream.class)
|
||||
.withFunctionName("featuresSelection", FeaturesSelectionStream.class)
|
||||
;
|
||||
}
|
||||
|
||||
|
@ -138,7 +140,6 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
|
|||
assertTrue(expressionString.contains("checkpointEvery=1000"));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testStatsStream() throws Exception {
|
||||
|
||||
|
@ -343,6 +344,40 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
|
|||
assertTrue(secondExpressionString.contains("q=\"presentTitles:\\\"chief, executive officer\\\" AND age:[36 TO *]\""));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFeaturesSelectionStream() throws Exception {
|
||||
String expr = "featuresSelection(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4, positiveLabel=2)";
|
||||
FeaturesSelectionStream stream = new FeaturesSelectionStream(StreamExpressionParser.parse(expr), factory);
|
||||
String expressionString = stream.toExpression(factory).toString();
|
||||
assertTrue(expressionString.contains("q=\"*:*\""));
|
||||
assertTrue(expressionString.contains("featureSet=first"));
|
||||
assertTrue(expressionString.contains("field=tv_text"));
|
||||
assertTrue(expressionString.contains("outcome=out_i"));
|
||||
assertTrue(expressionString.contains("numTerms=4"));
|
||||
assertTrue(expressionString.contains("positiveLabel=2"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTextLogitStreamWithFeaturesSelection() throws Exception {
|
||||
String expr = "tlogit(" +
|
||||
"collection1, " +
|
||||
"q=\"*:*\", " +
|
||||
"name=\"model\", " +
|
||||
"featuresSelection(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4), " +
|
||||
"field=\"tv_text\", " +
|
||||
"outcome=\"out_i\", " +
|
||||
"maxIterations=100)";
|
||||
TextLogitStream logitStream = new TextLogitStream(StreamExpressionParser.parse(expr), factory);
|
||||
String expressionString = logitStream.toExpression(factory).toString();
|
||||
assertTrue(expressionString.contains("q=\"*:*\""));
|
||||
assertTrue(expressionString.contains("name=model"));
|
||||
assertFalse(expressionString.contains("terms="));
|
||||
assertTrue(expressionString.contains("featuresSelection("));
|
||||
assertTrue(expressionString.contains("field=tv_text"));
|
||||
assertTrue(expressionString.contains("outcome=out_i"));
|
||||
assertTrue(expressionString.contains("maxIterations=100"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCountMetric() throws Exception {
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.apache.solr.client.solrj.io.stream;
|
||||
|
||||
import junit.framework.Assert;
|
||||
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.apache.solr.client.solrj.io.ops.GroupOperation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||
|
|
Loading…
Reference in New Issue