SOLR-11241: Add discrete counting and probability Stream Evaluators

This commit is contained in:
Joel Bernstein 2017-09-06 11:02:24 -04:00
parent 1fe611d094
commit 4496612120
12 changed files with 509 additions and 20 deletions

View File

@ -94,7 +94,7 @@ org.apache.calcite.version = 1.13.0
/org.apache.commons/commons-compress = 1.11
/org.apache.commons/commons-exec = 1.3
/org.apache.commons/commons-lang3 = 3.6
/org.apache.commons/commons-math3 = 3.4.1
/org.apache.commons/commons-math3 = 3.6.1
org.apache.curator.version = 2.8.0
/org.apache.curator/curator-client = ${org.apache.curator.version}

View File

@ -29,6 +29,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.solr.client.solrj.io.ModelCache;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
@ -249,6 +250,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("copyOfRange", CopyOfRangeEvaluator.class)
.withFunctionName("copyOf", CopyOfEvaluator.class)
.withFunctionName("cov", CovarianceEvaluator.class)
.withFunctionName("corr", CorrelationEvaluator.class)
.withFunctionName("describe", DescribeEvaluator.class)
.withFunctionName("distance", EuclideanDistanceEvaluator.class)
.withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class)
@ -281,9 +283,18 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("ebeDivide", EBEDivideEvaluator.class)
.withFunctionName("dotProduct", DotProductEvaluator.class)
.withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
.withFunctionName("freqTable", FrequencyTableEvaluator.class)
.withFunctionName("uniformIntegerDistribution", UniformIntegerDistributionEvaluator.class)
.withFunctionName("binomialDistribution", BinomialDistributionEvaluator.class)
.withFunctionName("poissonDistribution", PoissonDistributionEvaluator.class)
.withFunctionName("enumeratedDistribution", EnumeratedDistributionEvaluator.class)
.withFunctionName("probability", ProbabilityEvaluator.class)
// Boolean Stream Evaluators
.withFunctionName("and", AndEvaluator.class)
.withFunctionName("and", AndEvaluator.class)
.withFunctionName("eor", ExclusiveOrEvaluator.class)
.withFunctionName("eq", EqualToEvaluator.class)
.withFunctionName("gt", GreaterThanEvaluator.class)
@ -331,7 +342,6 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("cbrt", CubedRootEvaluator.class)
.withFunctionName("coalesce", CoalesceEvaluator.class)
.withFunctionName("uuid", UuidEvaluator.class)
.withFunctionName("corr", CorrelationEvaluator.class)
// Conditional Stream Evaluators
.withFunctionName("if", IfThenElseEvaluator.class)

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class BinomialDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public BinomialDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first, Object second) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
Number numberOfTrials = (Number)first;
Number successProb = (Number)second;
return new BinomialDistribution(numberOfTrials.intValue(), successProb.doubleValue());
}
}

View File

@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -38,17 +39,21 @@ public class CumulativeProbabilityEvaluator extends RecursiveObjectEvaluator imp
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
if(!(first instanceof RealDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a real or integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if(!(second instanceof Number)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
RealDistribution rd = (RealDistribution)first;
Number predictOver = (Number)second;
return rd.cumulativeProbability(predictOver.doubleValue());
if(first instanceof RealDistribution) {
RealDistribution rd = (RealDistribution) first;
Number predictOver = (Number) second;
return rd.cumulativeProbability(predictOver.doubleValue());
} else {
IntegerDistribution id = (IntegerDistribution) first;
Number predictOver = (Number) second;
return id.cumulativeProbability(predictOver.intValue());
}
}
}

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
private static final long serialVersionUID = 1;
public EnumeratedDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
int[] samples = ((List)first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
return new EnumeratedIntegerDistribution(samples);
}
}

View File

@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.commons.math3.stat.Frequency;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class FrequencyTableEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
public FrequencyTableEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(containedEvaluators.size() < 1){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object... values) throws IOException {
if(Arrays.stream(values).anyMatch(item -> null == item)){
return null;
}
List<?> sourceValues;
if(values.length == 1){
sourceValues = values[0] instanceof List<?> ? (List<?>)values[0] : Arrays.asList(values[0]);
}
else
{
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",toExpression(constructingFactory),containedEvaluators.size()));
}
Frequency frequency = new Frequency();
for(Object o : sourceValues) {
Number number = (Number)o;
frequency.addValue(number.longValue());
}
List<Tuple> histogramBins = new ArrayList<>();
Iterator iterator = frequency.valuesIterator();
while(iterator.hasNext()){
Long value = (Long)iterator.next();
Map<String,Number> map = new HashMap<>();
map.put("value", value.longValue());
map.put("count", frequency.getCount(value));
map.put("cumFreq", frequency.getCumFreq(value));
map.put("cumPct", frequency.getCumPct(value));
map.put("pct", frequency.getPct(value));
histogramBins.add(new Tuple(map));
}
return histogramBins;
}
}

View File

@ -26,6 +26,7 @@ import java.util.Map;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -68,7 +69,7 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
EmpiricalDistribution distribution = new EmpiricalDistribution(bins);
distribution.load(((List<?>)sourceValues).stream().mapToDouble(value -> ((Number)value).doubleValue()).toArray());;
List<Map<String,Number>> histogramBins = new ArrayList<>();
List<Tuple> histogramBins = new ArrayList<>();
for(SummaryStatistics binSummary : distribution.getBinStats()) {
Map<String,Number> map = new HashMap<>();
map.put("max", binSummary.getMax());
@ -78,7 +79,9 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
map.put("sum", binSummary.getSum());
map.put("N", binSummary.getN());
map.put("var", binSummary.getVariance());
histogramBins.add(map);
map.put("cumProb", distribution.cumulativeProbability(binSummary.getMean()));
map.put("prob", distribution.probability(binSummary.getMin(), binSummary.getMax()));
histogramBins.add(new Tuple(map));
}
return histogramBins;

View File

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class PoissonDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
private static final long serialVersionUID = 1;
public PoissonDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
Number mean = (Number)first;
return new PoissonDistribution(mean.intValue());
}
}

View File

@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public ProbabilityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
}
@Override
public Object doWork(Object first, Object second) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
if(!(first instanceof IntegerDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a IntegerDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if(!(second instanceof Number)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
IntegerDistribution d = (IntegerDistribution) first;
Number predictOver = (Number) second;
return d.probability(predictOver.intValue());
}
}

View File

@ -22,6 +22,7 @@ import java.util.Arrays;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -42,16 +43,19 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements TwoValu
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
if(!(first instanceof RealDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a Real or Integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if(!(second instanceof Number)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
RealDistribution realDistribution = (RealDistribution)first;
return Arrays.stream(realDistribution.sample(((Number)second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
}
if(first instanceof RealDistribution) {
RealDistribution realDistribution = (RealDistribution) first;
return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
} else {
IntegerDistribution integerDistribution = (IntegerDistribution) first;
return Arrays.stream(integerDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
}
}
}

View File

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.Locale;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class UniformIntegerDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
private static final long serialVersionUID = 1;
public UniformIntegerDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object first, Object second) throws IOException{
if(null == first){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
Number lower = (Number)first;
Number upper = (Number)second;
return new UniformIntegerDistribution(lower.intValue(), upper.intValue());
}
}

View File

@ -6155,6 +6155,48 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(out.get(5).intValue() == 2);
}
@Test
public void testFreqTable() throws Exception {
String cexpr = "freqTable(array(2,4,6,8,10,12,12,4,8,8,8,2))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<Map<String,Number>> out = (List<Map<String, Number>>)tuples.get(0).get("return-value");
assertTrue(out.size() == 6);
Map<String, Number> bucket = out.get(0);
assertEquals(bucket.get("value").longValue(), 2);
assertEquals(bucket.get("count").longValue(), 2);
bucket = out.get(1);
assertEquals(bucket.get("value").longValue(), 4);
assertEquals(bucket.get("count").longValue(), 2);
bucket = out.get(2);
assertEquals(bucket.get("value").longValue(), 6);
assertEquals(bucket.get("count").longValue(), 1);
bucket = out.get(3);
assertEquals(bucket.get("value").longValue(), 8);
assertEquals(bucket.get("count").longValue(), 4);
bucket = out.get(4);
assertEquals(bucket.get("value").longValue(), 10);
assertEquals(bucket.get("count").longValue(), 1);
bucket = out.get(5);
assertEquals(bucket.get("value").longValue(), 12);
assertEquals(bucket.get("count").longValue(), 2);
}
@Test
public void testCosineSimilarity() throws Exception {
String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))";
@ -6171,7 +6213,109 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(cs.doubleValue() == 0.9838197164968291);
}
@Test
public void testPoissonDistribution() throws Exception {
String cexpr = "let(a=poissonDistribution(100)," +
" b=sample(a, 10000)," +
" tuple(d=describe(b), " +
" p=probability(a, 100), " +
" c=cumulativeProbability(a, 100)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map map = (Map)tuples.get(0).get("d");
Number mean = (Number)map.get("mean");
Number var = (Number)map.get("var");
//The mean and variance should be almost the same for poisson distribution
assertEquals(mean.doubleValue(), var.doubleValue(), 3.0);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(), 0.03986099680914713, 0.0);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5265621985303708, 0.0);
}
@Test
public void testBinomialDistribution() throws Exception {
String cexpr = "let(a=binomialDistribution(100, .50)," +
" b=sample(a, 10000)," +
" tuple(d=describe(b), " +
" p=probability(a, 50), " +
" c=cumulativeProbability(a, 50)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(),0.07958923738717877, 0.0);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5397946186935851, 0.0);
}
@Test
public void testUniformIntegerDistribution() throws Exception {
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
" b=sample(a, 10000)," +
" tuple(d=describe(b), " +
" p=probability(a, 5), " +
" c=cumulativeProbability(a, 5)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map map = (Map)tuples.get(0).get("d");
Number N = (Number)map.get("N");
assertEquals(N.intValue(), 10000);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(), 0.1, 0.0);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5, 0.0);
}
@Test
public void testEnumeratedDistribution() throws Exception {
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
" b=sample(a, 10000)," +
" c=enumeratedDistribution(b),"+
" tuple(d=describe(b), " +
" p=probability(c, 5), " +
" c=cumulativeProbability(c, 5)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Map map = (Map)tuples.get(0).get("d");
Number N = (Number)map.get("N");
assertEquals(N.intValue(), 10000);
Number prob = (Number)tuples.get(0).get("p");
assertEquals(prob.doubleValue(), 0.1, 0.07);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5, 0.07);
}
@Test
public void testDotProduct() throws Exception {