SOLR-10696: Add cumulative probability function

This commit is contained in:
Joel Bernstein 2017-06-02 15:19:38 -04:00
parent f275e3b254
commit 99ca13f90f
4 changed files with 143 additions and 13 deletions

View File

@ -187,6 +187,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("copyOfRange", CopyOfRangeEvaluator.class) .withFunctionName("copyOfRange", CopyOfRangeEvaluator.class)
.withFunctionName("percentile", PercentileEvaluator.class) .withFunctionName("percentile", PercentileEvaluator.class)
.withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class) .withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class)
.withFunctionName("cumulativeProbability", CumulativeProbabilityEvaluator.class)
.withFunctionName("describe", DescribeEvaluator.class) .withFunctionName("describe", DescribeEvaluator.class)
.withFunctionName("finddelay", FindDelayEvaluator.class) .withFunctionName("finddelay", FindDelayEvaluator.class)
.withFunctionName("sequence", SequenceEvaluator.class) .withFunctionName("sequence", SequenceEvaluator.class)

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class CumulativeProbabilityEvaluator extends ComplexEvaluator implements Expressible {
private static final long serialVersionUID = 1;
public CumulativeProbabilityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
public Number evaluate(Tuple tuple) throws IOException {
if(subEvaluators.size() != 2) {
throw new IOException("Cumulative probability expects 2 parameters: an emperical distribution and a number");
}
StreamEvaluator r = subEvaluators.get(0);
StreamEvaluator d = subEvaluators.get(1);
EmpiricalDistributionEvaluator.EmpiricalDistributionTuple e = (EmpiricalDistributionEvaluator.EmpiricalDistributionTuple)r.evaluate(tuple);
Number n = (Number)d.evaluate(tuple);
return e.percentile(n.doubleValue());
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
StreamExpression expression = new StreamExpression(factory.getFunctionName(getClass()));
return expression;
}
@Override
public Explanation toExplanation(StreamFactory factory) throws IOException {
return new Explanation(nodeId.toString())
.withExpressionType(ExpressionType.EVALUATOR)
.withFunctionName(factory.getFunctionName(getClass()))
.withImplementingClass(getClass().getName())
.withExpression(toExpression(factory).toString());
}
}

View File

@ -17,6 +17,9 @@
package org.apache.solr.client.solrj.io.eval; package org.apache.solr.client.solrj.io.eval;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.List;
import java.io.IOException; import java.io.IOException;
import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.Tuple;
@ -26,6 +29,7 @@ import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
public class PercentileEvaluator extends ComplexEvaluator implements Expressible { public class PercentileEvaluator extends ComplexEvaluator implements Expressible {
@ -36,17 +40,23 @@ public class PercentileEvaluator extends ComplexEvaluator implements Expressible
} }
public Number evaluate(Tuple tuple) throws IOException { public Number evaluate(Tuple tuple) throws IOException {
if(subEvaluators.size() != 2) { if(subEvaluators.size() != 2) {
throw new IOException("Percentile expects 2 parameters: a regression result and a number"); throw new IOException("Percentile expects 2 parameters: an array and a number");
} }
StreamEvaluator r = subEvaluators.get(0); StreamEvaluator colEval = subEvaluators.get(0);
StreamEvaluator d = subEvaluators.get(1); List<Number> column = (List<Number>)colEval.evaluate(tuple);
EmpiricalDistributionEvaluator.EmpiricalDistributionTuple e = (EmpiricalDistributionEvaluator.EmpiricalDistributionTuple)r.evaluate(tuple); double[] data = new double[column.size()];
Number n = (Number)d.evaluate(tuple); for(int i=0; i<data.length; i++) {
return e.percentile(n.doubleValue()); data[i] = column.get(i).doubleValue();
}
Percentile percentile = new Percentile();
percentile.setData(data);
StreamEvaluator numEval = subEvaluators.get(1);
Number num = (Number)numEval.evaluate(tuple);
return percentile.evaluate(num.doubleValue());
} }
@Override @Override

View File

@ -5759,8 +5759,61 @@ public class StreamExpressionTest extends SolrCloudTestCase {
} }
@Test @Test
public void testPercentiles() throws Exception { public void testPercentile() throws Exception {
String cexpr = "percentile(array(1,2,3,4,5,6,7,8,9,10,11), 50)";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Tuple tuple = tuples.get(0);
double p = tuple.getDouble("return-value");
assertEquals(p, 6, 0.0);
cexpr = "percentile(array(11,10,3,4,5,6,7,8,9,2,1), 50)";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
tuple = tuples.get(0);
p = tuple.getDouble("return-value");
assertEquals(p, 6, 0.0);
cexpr = "percentile(array(11,10,3,4,5,6,7,8,9,2,1), 20)";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
tuple = tuples.get(0);
p = tuple.getDouble("return-value");
assertEquals(p, 2.4, 0.001);
}
@Test
public void testCumulativeProbability() throws Exception {
UpdateRequest updateRequest = new UpdateRequest(); UpdateRequest updateRequest = new UpdateRequest();
int i=0; int i=0;
@ -5773,11 +5826,11 @@ public class StreamExpressionTest extends SolrCloudTestCase {
String expr = "search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"price_f\", sort=\"price_f asc\", rows=\"200\")"; String expr = "search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"price_f\", sort=\"price_f asc\", rows=\"200\")";
String cexpr = "let(a="+expr+", c=col(a, price_f), e=empiricalDistribution(c), " + String cexpr = "let(a="+expr+", c=col(a, price_f), e=empiricalDistribution(c), " +
"tuple(p1=percentile(e, 88), " + "tuple(p1=cumulativeProbability(e, 88), " +
"p2=percentile(e, 2), " + "p2=cumulativeProbability(e, 2), " +
"p3=percentile(e, 99), " + "p3=cumulativeProbability(e, 99), " +
"p4=percentile(e, 77), " + "p4=cumulativeProbability(e, 77), " +
"p5=percentile(e, 98)))"; "p5=cumulativeProbability(e, 98)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr); paramsLoc.set("expr", cexpr);