SOLR-11241: Add discrete counting and probability Stream Evaluators

This commit is contained in:
Joel Bernstein 2017-09-06 11:02:24 -04:00
parent 1fe611d094
commit 4496612120
12 changed files with 509 additions and 20 deletions

View File

@ -94,7 +94,7 @@ org.apache.calcite.version = 1.13.0
/org.apache.commons/commons-compress = 1.11 /org.apache.commons/commons-compress = 1.11
/org.apache.commons/commons-exec = 1.3 /org.apache.commons/commons-exec = 1.3
/org.apache.commons/commons-lang3 = 3.6 /org.apache.commons/commons-lang3 = 3.6
/org.apache.commons/commons-math3 = 3.4.1 /org.apache.commons/commons-math3 = 3.6.1
org.apache.curator.version = 2.8.0 org.apache.curator.version = 2.8.0
/org.apache.curator/curator-client = ${org.apache.curator.version} /org.apache.curator/curator-client = ${org.apache.curator.version}

View File

@ -29,6 +29,7 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.solr.client.solrj.io.ModelCache; import org.apache.solr.client.solrj.io.ModelCache;
import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.Tuple;
@ -249,6 +250,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("copyOfRange", CopyOfRangeEvaluator.class) .withFunctionName("copyOfRange", CopyOfRangeEvaluator.class)
.withFunctionName("copyOf", CopyOfEvaluator.class) .withFunctionName("copyOf", CopyOfEvaluator.class)
.withFunctionName("cov", CovarianceEvaluator.class) .withFunctionName("cov", CovarianceEvaluator.class)
.withFunctionName("corr", CorrelationEvaluator.class)
.withFunctionName("describe", DescribeEvaluator.class) .withFunctionName("describe", DescribeEvaluator.class)
.withFunctionName("distance", EuclideanDistanceEvaluator.class) .withFunctionName("distance", EuclideanDistanceEvaluator.class)
.withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class) .withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class)
@ -281,8 +283,17 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("ebeDivide", EBEDivideEvaluator.class) .withFunctionName("ebeDivide", EBEDivideEvaluator.class)
.withFunctionName("dotProduct", DotProductEvaluator.class) .withFunctionName("dotProduct", DotProductEvaluator.class)
.withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class) .withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
.withFunctionName("freqTable", FrequencyTableEvaluator.class)
.withFunctionName("uniformIntegerDistribution", UniformIntegerDistributionEvaluator.class)
.withFunctionName("binomialDistribution", BinomialDistributionEvaluator.class)
.withFunctionName("poissonDistribution", PoissonDistributionEvaluator.class)
.withFunctionName("enumeratedDistribution", EnumeratedDistributionEvaluator.class)
.withFunctionName("probability", ProbabilityEvaluator.class)
// Boolean Stream Evaluators // Boolean Stream Evaluators
.withFunctionName("and", AndEvaluator.class) .withFunctionName("and", AndEvaluator.class)
.withFunctionName("eor", ExclusiveOrEvaluator.class) .withFunctionName("eor", ExclusiveOrEvaluator.class)
.withFunctionName("eq", EqualToEvaluator.class) .withFunctionName("eq", EqualToEvaluator.class)
@ -331,7 +342,6 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("cbrt", CubedRootEvaluator.class) .withFunctionName("cbrt", CubedRootEvaluator.class)
.withFunctionName("coalesce", CoalesceEvaluator.class) .withFunctionName("coalesce", CoalesceEvaluator.class)
.withFunctionName("uuid", UuidEvaluator.class) .withFunctionName("uuid", UuidEvaluator.class)
.withFunctionName("corr", CorrelationEvaluator.class)
// Conditional Stream Evaluators // Conditional Stream Evaluators
.withFunctionName("if", IfThenElseEvaluator.class) .withFunctionName("if", IfThenElseEvaluator.class)

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class BinomialDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public BinomialDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first, Object second) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
Number numberOfTrials = (Number)first;
Number successProb = (Number)second;
return new BinomialDistribution(numberOfTrials.intValue(), successProb.doubleValue());
}
}

View File

@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.Locale; import java.util.Locale;
import org.apache.commons.math3.distribution.RealDistribution; import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -38,17 +39,21 @@ public class CumulativeProbabilityEvaluator extends RecursiveObjectEvaluator imp
if(null == second){ if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
} }
if(!(first instanceof RealDistribution)){ if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName())); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a real or integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
} }
if(!(second instanceof Number)){ if(!(second instanceof Number)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName())); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
} }
if(first instanceof RealDistribution) {
RealDistribution rd = (RealDistribution) first; RealDistribution rd = (RealDistribution) first;
Number predictOver = (Number) second; Number predictOver = (Number) second;
return rd.cumulativeProbability(predictOver.doubleValue()); return rd.cumulativeProbability(predictOver.doubleValue());
} else {
IntegerDistribution id = (IntegerDistribution) first;
Number predictOver = (Number) second;
return id.cumulativeProbability(predictOver.intValue());
}
} }
} }

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
private static final long serialVersionUID = 1;
public EnumeratedDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
int[] samples = ((List)first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
return new EnumeratedIntegerDistribution(samples);
}
}

View File

@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.commons.math3.stat.Frequency;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class FrequencyTableEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
public FrequencyTableEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(containedEvaluators.size() < 1){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object... values) throws IOException {
if(Arrays.stream(values).anyMatch(item -> null == item)){
return null;
}
List<?> sourceValues;
if(values.length == 1){
sourceValues = values[0] instanceof List<?> ? (List<?>)values[0] : Arrays.asList(values[0]);
}
else
{
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",toExpression(constructingFactory),containedEvaluators.size()));
}
Frequency frequency = new Frequency();
for(Object o : sourceValues) {
Number number = (Number)o;
frequency.addValue(number.longValue());
}
List<Tuple> histogramBins = new ArrayList<>();
Iterator iterator = frequency.valuesIterator();
while(iterator.hasNext()){
Long value = (Long)iterator.next();
Map<String,Number> map = new HashMap<>();
map.put("value", value.longValue());
map.put("count", frequency.getCount(value));
map.put("cumFreq", frequency.getCumFreq(value));
map.put("cumPct", frequency.getCumPct(value));
map.put("pct", frequency.getPct(value));
histogramBins.add(new Tuple(map));
}
return histogramBins;
}
}

View File

@ -26,6 +26,7 @@ import java.util.Map;
import org.apache.commons.math3.random.EmpiricalDistribution; import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics; import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -68,7 +69,7 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
EmpiricalDistribution distribution = new EmpiricalDistribution(bins); EmpiricalDistribution distribution = new EmpiricalDistribution(bins);
distribution.load(((List<?>)sourceValues).stream().mapToDouble(value -> ((Number)value).doubleValue()).toArray());; distribution.load(((List<?>)sourceValues).stream().mapToDouble(value -> ((Number)value).doubleValue()).toArray());;
List<Map<String,Number>> histogramBins = new ArrayList<>(); List<Tuple> histogramBins = new ArrayList<>();
for(SummaryStatistics binSummary : distribution.getBinStats()) { for(SummaryStatistics binSummary : distribution.getBinStats()) {
Map<String,Number> map = new HashMap<>(); Map<String,Number> map = new HashMap<>();
map.put("max", binSummary.getMax()); map.put("max", binSummary.getMax());
@ -78,7 +79,9 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
map.put("sum", binSummary.getSum()); map.put("sum", binSummary.getSum());
map.put("N", binSummary.getN()); map.put("N", binSummary.getN());
map.put("var", binSummary.getVariance()); map.put("var", binSummary.getVariance());
histogramBins.add(map); map.put("cumProb", distribution.cumulativeProbability(binSummary.getMean()));
map.put("prob", distribution.probability(binSummary.getMin(), binSummary.getMax()));
histogramBins.add(new Tuple(map));
} }
return histogramBins; return histogramBins;

View File

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class PoissonDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
private static final long serialVersionUID = 1;
public PoissonDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
Number mean = (Number)first;
return new PoissonDistribution(mean.intValue());
}
}

View File

@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public ProbabilityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
}
@Override
public Object doWork(Object first, Object second) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
if(!(first instanceof IntegerDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a IntegerDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if(!(second instanceof Number)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
IntegerDistribution d = (IntegerDistribution) first;
Number predictOver = (Number) second;
return d.probability(predictOver.intValue());
}
}

View File

@ -22,6 +22,7 @@ import java.util.Arrays;
import java.util.Locale; import java.util.Locale;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.RealDistribution; import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -42,16 +43,19 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements TwoValu
if(null == second){ if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
} }
if(!(first instanceof RealDistribution)){ if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName())); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a Real or Integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
} }
if(!(second instanceof Number)){ if(!(second instanceof Number)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName())); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
} }
if(first instanceof RealDistribution) {
RealDistribution realDistribution = (RealDistribution) first; RealDistribution realDistribution = (RealDistribution) first;
return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList()); return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
} else {
IntegerDistribution integerDistribution = (IntegerDistribution) first;
return Arrays.stream(integerDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
}
} }
} }

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class UniformIntegerDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public UniformIntegerDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first, Object second) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
Number lower = (Number)first;
Number upper = (Number)second;
return new UniformIntegerDistribution(lower.intValue(), upper.intValue());
}
}

View File

@ -6155,6 +6155,48 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(out.get(5).intValue() == 2); assertTrue(out.get(5).intValue() == 2);
} }
@Test
public void testFreqTable() throws Exception {
String cexpr = "freqTable(array(2,4,6,8,10,12,12,4,8,8,8,2))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<Map<String,Number>> out = (List<Map<String, Number>>)tuples.get(0).get("return-value");
assertTrue(out.size() == 6);
Map<String, Number> bucket = out.get(0);
assertEquals(bucket.get("value").longValue(), 2);
assertEquals(bucket.get("count").longValue(), 2);
bucket = out.get(1);
assertEquals(bucket.get("value").longValue(), 4);
assertEquals(bucket.get("count").longValue(), 2);
bucket = out.get(2);
assertEquals(bucket.get("value").longValue(), 6);
assertEquals(bucket.get("count").longValue(), 1);
bucket = out.get(3);
assertEquals(bucket.get("value").longValue(), 8);
assertEquals(bucket.get("count").longValue(), 4);
bucket = out.get(4);
assertEquals(bucket.get("value").longValue(), 10);
assertEquals(bucket.get("count").longValue(), 1);
bucket = out.get(5);
assertEquals(bucket.get("value").longValue(), 12);
assertEquals(bucket.get("count").longValue(), 2);
}
@Test @Test
public void testCosineSimilarity() throws Exception { public void testCosineSimilarity() throws Exception {
String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))"; String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))";
@ -6171,7 +6213,109 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(cs.doubleValue() == 0.9838197164968291); assertTrue(cs.doubleValue() == 0.9838197164968291);
} }
@Test
public void testPoissonDistribution() throws Exception {
String cexpr = "let(a=poissonDistribution(100)," +
" b=sample(a, 10000)," +
" tuple(d=describe(b), " +
" p=probability(a, 100), " +
" c=cumulativeProbability(a, 100)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map map = (Map)tuples.get(0).get("d");
Number mean = (Number)map.get("mean");
Number var = (Number)map.get("var");
//The mean and variance should be almost the same for poisson distribution
assertEquals(mean.doubleValue(), var.doubleValue(), 3.0);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(), 0.03986099680914713, 0.0);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5265621985303708, 0.0);
}
@Test
public void testBinomialDistribution() throws Exception {
String cexpr = "let(a=binomialDistribution(100, .50)," +
" b=sample(a, 10000)," +
" tuple(d=describe(b), " +
" p=probability(a, 50), " +
" c=cumulativeProbability(a, 50)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(),0.07958923738717877, 0.0);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5397946186935851, 0.0);
}
@Test
public void testUniformIntegerDistribution() throws Exception {
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
" b=sample(a, 10000)," +
" tuple(d=describe(b), " +
" p=probability(a, 5), " +
" c=cumulativeProbability(a, 5)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map map = (Map)tuples.get(0).get("d");
Number N = (Number)map.get("N");
assertEquals(N.intValue(), 10000);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(), 0.1, 0.0);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5, 0.0);
}
@Test
public void testEnumeratedDistribution() throws Exception {
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
" b=sample(a, 10000)," +
" c=enumeratedDistribution(b),"+
" tuple(d=describe(b), " +
" p=probability(c, 5), " +
" c=cumulativeProbability(c, 5)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map map = (Map)tuples.get(0).get("d");
Number N = (Number)map.get("N");
assertEquals(N.intValue(), 10000);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(), 0.1, 0.07);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5, 0.07);
}
@Test @Test
public void testDotProduct() throws Exception { public void testDotProduct() throws Exception {