From 4496612120e361fc9cf2df27115297ceb35a81cc Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Wed, 6 Sep 2017 11:02:24 -0400 Subject: [PATCH] SOLR-11241: Add discrete counting and probability Stream Evaluators --- lucene/ivy-versions.properties | 2 +- .../apache/solr/handler/StreamHandler.java | 14 +- .../eval/BinomialDistributionEvaluator.java | 48 ++++++ .../eval/CumulativeProbabilityEvaluator.java | 21 ++- .../eval/EnumeratedDistributionEvaluator.java | 45 ++++++ .../io/eval/FrequencyTableEvaluator.java | 86 +++++++++++ .../solrj/io/eval/HistogramEvaluator.java | 7 +- .../io/eval/PoissonDistributionEvaluator.java | 44 ++++++ .../solrj/io/eval/ProbabilityEvaluator.java | 52 +++++++ .../client/solrj/io/eval/SampleEvaluator.java | 18 ++- .../UniformIntegerDistributionEvaluator.java | 48 ++++++ .../solrj/io/stream/StreamExpressionTest.java | 144 ++++++++++++++++++ 12 files changed, 509 insertions(+), 20 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/BinomialDistributionEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FrequencyTableEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PoissonDistributionEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UniformIntegerDistributionEvaluator.java diff --git a/lucene/ivy-versions.properties b/lucene/ivy-versions.properties index 66b7c89f9f8..8c21e99f8a5 100644 --- a/lucene/ivy-versions.properties +++ b/lucene/ivy-versions.properties @@ -94,7 +94,7 @@ org.apache.calcite.version = 1.13.0 /org.apache.commons/commons-compress = 1.11 /org.apache.commons/commons-exec = 1.3 /org.apache.commons/commons-lang3 = 3.6 -/org.apache.commons/commons-math3 = 3.4.1 +/org.apache.commons/commons-math3 = 3.6.1 org.apache.curator.version = 2.8.0 /org.apache.curator/curator-client = ${org.apache.curator.version} diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index 9613ec33a06..5396acd9f04 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -29,6 +29,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.solr.client.solrj.io.ModelCache; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; @@ -249,6 +250,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("copyOfRange", CopyOfRangeEvaluator.class) .withFunctionName("copyOf", CopyOfEvaluator.class) .withFunctionName("cov", CovarianceEvaluator.class) + .withFunctionName("corr", CorrelationEvaluator.class) .withFunctionName("describe", DescribeEvaluator.class) .withFunctionName("distance", EuclideanDistanceEvaluator.class) .withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class) @@ -281,9 +283,18 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("ebeDivide", EBEDivideEvaluator.class) .withFunctionName("dotProduct", DotProductEvaluator.class) .withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class) + .withFunctionName("freqTable", FrequencyTableEvaluator.class) + .withFunctionName("uniformIntegerDistribution", UniformIntegerDistributionEvaluator.class) + .withFunctionName("binomialDistribution", BinomialDistributionEvaluator.class) + .withFunctionName("poissonDistribution", PoissonDistributionEvaluator.class) + .withFunctionName("enumeratedDistribution", EnumeratedDistributionEvaluator.class) + .withFunctionName("probability", ProbabilityEvaluator.class) // Boolean Stream Evaluators - .withFunctionName("and", AndEvaluator.class) + + + + .withFunctionName("and", AndEvaluator.class) .withFunctionName("eor", ExclusiveOrEvaluator.class) .withFunctionName("eq", EqualToEvaluator.class) .withFunctionName("gt", GreaterThanEvaluator.class) @@ -331,7 +342,6 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("cbrt", CubedRootEvaluator.class) .withFunctionName("coalesce", CoalesceEvaluator.class) .withFunctionName("uuid", UuidEvaluator.class) - .withFunctionName("corr", CorrelationEvaluator.class) // Conditional Stream Evaluators .withFunctionName("if", IfThenElseEvaluator.class) diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/BinomialDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/BinomialDistributionEvaluator.java new file mode 100644 index 00000000000..a0a6fb7664d --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/BinomialDistributionEvaluator.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.commons.math3.distribution.BinomialDistribution; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class BinomialDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker { + + private static final long serialVersionUID = 1; + + public BinomialDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object first, Object second) throws IOException{ + if(null == first){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory))); + } + if(null == second){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); + } + + Number numberOfTrials = (Number)first; + Number successProb = (Number)second; + + return new BinomialDistribution(numberOfTrials.intValue(), successProb.doubleValue()); + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CumulativeProbabilityEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CumulativeProbabilityEvaluator.java index 8ecf35da9f7..a7e6d5a2017 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CumulativeProbabilityEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CumulativeProbabilityEvaluator.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.Locale; import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.IntegerDistribution; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -38,17 +39,21 @@ public class CumulativeProbabilityEvaluator extends RecursiveObjectEvaluator imp if(null == second){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); } - if(!(first instanceof RealDistribution)){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName())); + if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a real or integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName())); } if(!(second instanceof Number)){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName())); } - - RealDistribution rd = (RealDistribution)first; - Number predictOver = (Number)second; - - return rd.cumulativeProbability(predictOver.doubleValue()); + + if(first instanceof RealDistribution) { + RealDistribution rd = (RealDistribution) first; + Number predictOver = (Number) second; + return rd.cumulativeProbability(predictOver.doubleValue()); + } else { + IntegerDistribution id = (IntegerDistribution) first; + Number predictOver = (Number) second; + return id.cumulativeProbability(predictOver.intValue()); + } } - } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java new file mode 100644 index 00000000000..a14e54b43fc --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.List; +import java.util.Locale; + +import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker { + + private static final long serialVersionUID = 1; + + public EnumeratedDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object first) throws IOException{ + if(null == first){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory))); + } + + int[] samples = ((List)first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray(); + return new EnumeratedIntegerDistribution(samples); + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FrequencyTableEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FrequencyTableEvaluator.java new file mode 100644 index 00000000000..ae65dc1e33d --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/FrequencyTableEvaluator.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.commons.math3.random.EmpiricalDistribution; +import org.apache.commons.math3.stat.Frequency; + +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class FrequencyTableEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker { + protected static final long serialVersionUID = 1L; + + public FrequencyTableEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + if(containedEvaluators.size() < 1){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",expression,containedEvaluators.size())); + } + } + + @Override + public Object doWork(Object... values) throws IOException { + if(Arrays.stream(values).anyMatch(item -> null == item)){ + return null; + } + + List sourceValues; + + if(values.length == 1){ + sourceValues = values[0] instanceof List ? (List)values[0] : Arrays.asList(values[0]); + } + else + { + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",toExpression(constructingFactory),containedEvaluators.size())); + } + + Frequency frequency = new Frequency(); + + for(Object o : sourceValues) { + Number number = (Number)o; + frequency.addValue(number.longValue()); + } + + List histogramBins = new ArrayList<>(); + + Iterator iterator = frequency.valuesIterator(); + + while(iterator.hasNext()){ + Long value = (Long)iterator.next(); + Map map = new HashMap<>(); + map.put("value", value.longValue()); + map.put("count", frequency.getCount(value)); + map.put("cumFreq", frequency.getCumFreq(value)); + map.put("cumPct", frequency.getCumPct(value)); + map.put("pct", frequency.getPct(value)); + histogramBins.add(new Tuple(map)); + } + return histogramBins; + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/HistogramEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/HistogramEvaluator.java index f58f319f650..8d2761469f9 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/HistogramEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/HistogramEvaluator.java @@ -26,6 +26,7 @@ import java.util.Map; import org.apache.commons.math3.random.EmpiricalDistribution; import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -68,7 +69,7 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man EmpiricalDistribution distribution = new EmpiricalDistribution(bins); distribution.load(((List)sourceValues).stream().mapToDouble(value -> ((Number)value).doubleValue()).toArray());; - List> histogramBins = new ArrayList<>(); + List histogramBins = new ArrayList<>(); for(SummaryStatistics binSummary : distribution.getBinStats()) { Map map = new HashMap<>(); map.put("max", binSummary.getMax()); @@ -78,7 +79,9 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man map.put("sum", binSummary.getSum()); map.put("N", binSummary.getN()); map.put("var", binSummary.getVariance()); - histogramBins.add(map); + map.put("cumProb", distribution.cumulativeProbability(binSummary.getMean())); + map.put("prob", distribution.probability(binSummary.getMin(), binSummary.getMax())); + histogramBins.add(new Tuple(map)); } return histogramBins; diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PoissonDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PoissonDistributionEvaluator.java new file mode 100644 index 00000000000..864a039d5ad --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PoissonDistributionEvaluator.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.commons.math3.distribution.PoissonDistribution; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class PoissonDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker { + + private static final long serialVersionUID = 1; + + public PoissonDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object first) throws IOException{ + if(null == first){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory))); + } + + Number mean = (Number)first; + + return new PoissonDistribution(mean.intValue()); + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java new file mode 100644 index 00000000000..f0c25cb5846 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.commons.math3.distribution.IntegerDistribution; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + protected static final long serialVersionUID = 1L; + + public ProbabilityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + } + + @Override + public Object doWork(Object first, Object second) throws IOException{ + if(null == first){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory))); + } + if(null == second){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); + } + if(!(first instanceof IntegerDistribution)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a IntegerDistribution",toExpression(constructingFactory), first.getClass().getSimpleName())); + } + if(!(second instanceof Number)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName())); + } + + IntegerDistribution d = (IntegerDistribution) first; + Number predictOver = (Number) second; + return d.probability(predictOver.intValue()); + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java index 5732f8533b3..be983055d62 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Locale; import java.util.stream.Collectors; +import org.apache.commons.math3.distribution.IntegerDistribution; import org.apache.commons.math3.distribution.RealDistribution; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -42,16 +43,19 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements TwoValu if(null == second){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); } - if(!(first instanceof RealDistribution)){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName())); + if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a Real or Integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName())); } if(!(second instanceof Number)){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName())); } - - RealDistribution realDistribution = (RealDistribution)first; - - return Arrays.stream(realDistribution.sample(((Number)second).intValue())).mapToObj(item -> item).collect(Collectors.toList()); - } + if(first instanceof RealDistribution) { + RealDistribution realDistribution = (RealDistribution) first; + return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList()); + } else { + IntegerDistribution integerDistribution = (IntegerDistribution) first; + return Arrays.stream(integerDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList()); + } + } } \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UniformIntegerDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UniformIntegerDistributionEvaluator.java new file mode 100644 index 00000000000..adf1aff4abe --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/UniformIntegerDistributionEvaluator.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.commons.math3.distribution.UniformIntegerDistribution; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class UniformIntegerDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker { + + private static final long serialVersionUID = 1; + + public UniformIntegerDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object first, Object second) throws IOException{ + if(null == first){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory))); + } + if(null == second){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); + } + + Number lower = (Number)first; + Number upper = (Number)second; + + return new UniformIntegerDistribution(lower.intValue(), upper.intValue()); + } +} \ No newline at end of file diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index f831daca740..fafac6ff73c 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -6155,6 +6155,48 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertTrue(out.get(5).intValue() == 2); } + + @Test + public void testFreqTable() throws Exception { + String cexpr = "freqTable(array(2,4,6,8,10,12,12,4,8,8,8,2))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + List> out = (List>)tuples.get(0).get("return-value"); + assertTrue(out.size() == 6); + Map bucket = out.get(0); + assertEquals(bucket.get("value").longValue(), 2); + assertEquals(bucket.get("count").longValue(), 2); + + bucket = out.get(1); + assertEquals(bucket.get("value").longValue(), 4); + assertEquals(bucket.get("count").longValue(), 2); + + bucket = out.get(2); + assertEquals(bucket.get("value").longValue(), 6); + assertEquals(bucket.get("count").longValue(), 1); + + bucket = out.get(3); + assertEquals(bucket.get("value").longValue(), 8); + assertEquals(bucket.get("count").longValue(), 4); + + bucket = out.get(4); + assertEquals(bucket.get("value").longValue(), 10); + assertEquals(bucket.get("count").longValue(), 1); + + bucket = out.get(5); + assertEquals(bucket.get("value").longValue(), 12); + assertEquals(bucket.get("count").longValue(), 2); + } + + + @Test public void testCosineSimilarity() throws Exception { String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))"; @@ -6171,7 +6213,109 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertTrue(cs.doubleValue() == 0.9838197164968291); } + @Test + public void testPoissonDistribution() throws Exception { + String cexpr = "let(a=poissonDistribution(100)," + + " b=sample(a, 10000)," + + " tuple(d=describe(b), " + + " p=probability(a, 100), " + + " c=cumulativeProbability(a, 100)))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Map map = (Map)tuples.get(0).get("d"); + Number mean = (Number)map.get("mean"); + Number var = (Number)map.get("var"); + //The mean and variance should be almost the same for poisson distribution + assertEquals(mean.doubleValue(), var.doubleValue(), 3.0); + Number prob = (Number)tuples.get(0).get("p"); + assertEquals(prob.doubleValue(), 0.03986099680914713, 0.0); + Number cprob = (Number)tuples.get(0).get("c"); + assertEquals(cprob.doubleValue(), 0.5265621985303708, 0.0); + } + + @Test + public void testBinomialDistribution() throws Exception { + String cexpr = "let(a=binomialDistribution(100, .50)," + + " b=sample(a, 10000)," + + " tuple(d=describe(b), " + + " p=probability(a, 50), " + + " c=cumulativeProbability(a, 50)))"; + + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Number prob = (Number)tuples.get(0).get("p"); + assertEquals(prob.doubleValue(),0.07958923738717877, 0.0); + Number cprob = (Number)tuples.get(0).get("c"); + assertEquals(cprob.doubleValue(), 0.5397946186935851, 0.0); + } + + @Test + public void testUniformIntegerDistribution() throws Exception { + String cexpr = "let(a=uniformIntegerDistribution(1, 10)," + + " b=sample(a, 10000)," + + " tuple(d=describe(b), " + + " p=probability(a, 5), " + + " c=cumulativeProbability(a, 5)))"; + + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Map map = (Map)tuples.get(0).get("d"); + Number N = (Number)map.get("N"); + assertEquals(N.intValue(), 10000); + Number prob = (Number)tuples.get(0).get("p"); + assertEquals(prob.doubleValue(), 0.1, 0.0); + Number cprob = (Number)tuples.get(0).get("c"); + assertEquals(cprob.doubleValue(), 0.5, 0.0); + } + + @Test + public void testEnumeratedDistribution() throws Exception { + String cexpr = "let(a=uniformIntegerDistribution(1, 10)," + + " b=sample(a, 10000)," + + " c=enumeratedDistribution(b),"+ + " tuple(d=describe(b), " + + " p=probability(c, 5), " + + " c=cumulativeProbability(c, 5)))"; + + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Map map = (Map)tuples.get(0).get("d"); + Number N = (Number)map.get("N"); + assertEquals(N.intValue(), 10000); + Number prob = (Number)tuples.get(0).get("p"); + assertEquals(prob.doubleValue(), 0.1, 0.07); + Number cprob = (Number)tuples.get(0).get("c"); + assertEquals(cprob.doubleValue(), 0.5, 0.07); + } @Test public void testDotProduct() throws Exception {