mirror of https://github.com/apache/lucene.git
SOLR-11241: Add discrete counting and probability Stream Evaluators
This commit is contained in:
parent
1fe611d094
commit
4496612120
|
@ -94,7 +94,7 @@ org.apache.calcite.version = 1.13.0
|
|||
/org.apache.commons/commons-compress = 1.11
|
||||
/org.apache.commons/commons-exec = 1.3
|
||||
/org.apache.commons/commons-lang3 = 3.6
|
||||
/org.apache.commons/commons-math3 = 3.4.1
|
||||
/org.apache.commons/commons-math3 = 3.6.1
|
||||
|
||||
org.apache.curator.version = 2.8.0
|
||||
/org.apache.curator/curator-client = ${org.apache.curator.version}
|
||||
|
|
|
@ -29,6 +29,7 @@ import java.util.Iterator;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.math3.distribution.PoissonDistribution;
|
||||
import org.apache.solr.client.solrj.io.ModelCache;
|
||||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
|
@ -249,6 +250,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
|||
.withFunctionName("copyOfRange", CopyOfRangeEvaluator.class)
|
||||
.withFunctionName("copyOf", CopyOfEvaluator.class)
|
||||
.withFunctionName("cov", CovarianceEvaluator.class)
|
||||
.withFunctionName("corr", CorrelationEvaluator.class)
|
||||
.withFunctionName("describe", DescribeEvaluator.class)
|
||||
.withFunctionName("distance", EuclideanDistanceEvaluator.class)
|
||||
.withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class)
|
||||
|
@ -281,8 +283,17 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
|||
.withFunctionName("ebeDivide", EBEDivideEvaluator.class)
|
||||
.withFunctionName("dotProduct", DotProductEvaluator.class)
|
||||
.withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
|
||||
.withFunctionName("freqTable", FrequencyTableEvaluator.class)
|
||||
.withFunctionName("uniformIntegerDistribution", UniformIntegerDistributionEvaluator.class)
|
||||
.withFunctionName("binomialDistribution", BinomialDistributionEvaluator.class)
|
||||
.withFunctionName("poissonDistribution", PoissonDistributionEvaluator.class)
|
||||
.withFunctionName("enumeratedDistribution", EnumeratedDistributionEvaluator.class)
|
||||
.withFunctionName("probability", ProbabilityEvaluator.class)
|
||||
|
||||
// Boolean Stream Evaluators
|
||||
|
||||
|
||||
|
||||
.withFunctionName("and", AndEvaluator.class)
|
||||
.withFunctionName("eor", ExclusiveOrEvaluator.class)
|
||||
.withFunctionName("eq", EqualToEvaluator.class)
|
||||
|
@ -331,7 +342,6 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
|||
.withFunctionName("cbrt", CubedRootEvaluator.class)
|
||||
.withFunctionName("coalesce", CoalesceEvaluator.class)
|
||||
.withFunctionName("uuid", UuidEvaluator.class)
|
||||
.withFunctionName("corr", CorrelationEvaluator.class)
|
||||
|
||||
// Conditional Stream Evaluators
|
||||
.withFunctionName("if", IfThenElseEvaluator.class)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.distribution.BinomialDistribution;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class BinomialDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
public BinomialDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
Number numberOfTrials = (Number)first;
|
||||
Number successProb = (Number)second;
|
||||
|
||||
return new BinomialDistribution(numberOfTrials.intValue(), successProb.doubleValue());
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
|||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
|
@ -38,17 +39,21 @@ public class CumulativeProbabilityEvaluator extends RecursiveObjectEvaluator imp
|
|||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof RealDistribution)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a real or integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof Number)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
if(first instanceof RealDistribution) {
|
||||
RealDistribution rd = (RealDistribution) first;
|
||||
Number predictOver = (Number) second;
|
||||
|
||||
return rd.cumulativeProbability(predictOver.doubleValue());
|
||||
} else {
|
||||
IntegerDistribution id = (IntegerDistribution) first;
|
||||
Number predictOver = (Number) second;
|
||||
return id.cumulativeProbability(predictOver.intValue());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
public EnumeratedDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
int[] samples = ((List)first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
|
||||
return new EnumeratedIntegerDistribution(samples);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.commons.math3.random.EmpiricalDistribution;
|
||||
import org.apache.commons.math3.stat.Frequency;
|
||||
|
||||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class FrequencyTableEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public FrequencyTableEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
if(containedEvaluators.size() < 1){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",expression,containedEvaluators.size()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
if(Arrays.stream(values).anyMatch(item -> null == item)){
|
||||
return null;
|
||||
}
|
||||
|
||||
List<?> sourceValues;
|
||||
|
||||
if(values.length == 1){
|
||||
sourceValues = values[0] instanceof List<?> ? (List<?>)values[0] : Arrays.asList(values[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",toExpression(constructingFactory),containedEvaluators.size()));
|
||||
}
|
||||
|
||||
Frequency frequency = new Frequency();
|
||||
|
||||
for(Object o : sourceValues) {
|
||||
Number number = (Number)o;
|
||||
frequency.addValue(number.longValue());
|
||||
}
|
||||
|
||||
List<Tuple> histogramBins = new ArrayList<>();
|
||||
|
||||
Iterator iterator = frequency.valuesIterator();
|
||||
|
||||
while(iterator.hasNext()){
|
||||
Long value = (Long)iterator.next();
|
||||
Map<String,Number> map = new HashMap<>();
|
||||
map.put("value", value.longValue());
|
||||
map.put("count", frequency.getCount(value));
|
||||
map.put("cumFreq", frequency.getCumFreq(value));
|
||||
map.put("cumPct", frequency.getCumPct(value));
|
||||
map.put("pct", frequency.getPct(value));
|
||||
histogramBins.add(new Tuple(map));
|
||||
}
|
||||
return histogramBins;
|
||||
}
|
||||
}
|
|
@ -26,6 +26,7 @@ import java.util.Map;
|
|||
|
||||
import org.apache.commons.math3.random.EmpiricalDistribution;
|
||||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
|
@ -68,7 +69,7 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
|
|||
EmpiricalDistribution distribution = new EmpiricalDistribution(bins);
|
||||
distribution.load(((List<?>)sourceValues).stream().mapToDouble(value -> ((Number)value).doubleValue()).toArray());;
|
||||
|
||||
List<Map<String,Number>> histogramBins = new ArrayList<>();
|
||||
List<Tuple> histogramBins = new ArrayList<>();
|
||||
for(SummaryStatistics binSummary : distribution.getBinStats()) {
|
||||
Map<String,Number> map = new HashMap<>();
|
||||
map.put("max", binSummary.getMax());
|
||||
|
@ -78,7 +79,9 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
|
|||
map.put("sum", binSummary.getSum());
|
||||
map.put("N", binSummary.getN());
|
||||
map.put("var", binSummary.getVariance());
|
||||
histogramBins.add(map);
|
||||
map.put("cumProb", distribution.cumulativeProbability(binSummary.getMean()));
|
||||
map.put("prob", distribution.probability(binSummary.getMin(), binSummary.getMax()));
|
||||
histogramBins.add(new Tuple(map));
|
||||
}
|
||||
|
||||
return histogramBins;
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.distribution.PoissonDistribution;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class PoissonDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
public PoissonDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
Number mean = (Number)first;
|
||||
|
||||
return new PoissonDistribution(mean.intValue());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public ProbabilityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof IntegerDistribution)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a IntegerDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof Number)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
IntegerDistribution d = (IntegerDistribution) first;
|
||||
Number predictOver = (Number) second;
|
||||
return d.probability(predictOver.intValue());
|
||||
}
|
||||
}
|
|
@ -22,6 +22,7 @@ import java.util.Arrays;
|
|||
import java.util.Locale;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
import org.apache.commons.math3.distribution.RealDistribution;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
@ -42,16 +43,19 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements TwoValu
|
|||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof RealDistribution)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a Real or Integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof Number)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
if(first instanceof RealDistribution) {
|
||||
RealDistribution realDistribution = (RealDistribution) first;
|
||||
|
||||
return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
|
||||
} else {
|
||||
IntegerDistribution integerDistribution = (IntegerDistribution) first;
|
||||
return Arrays.stream(integerDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class UniformIntegerDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
public UniformIntegerDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
Number lower = (Number)first;
|
||||
Number upper = (Number)second;
|
||||
|
||||
return new UniformIntegerDistribution(lower.intValue(), upper.intValue());
|
||||
}
|
||||
}
|
|
@ -6155,6 +6155,48 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assertTrue(out.get(5).intValue() == 2);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testFreqTable() throws Exception {
|
||||
String cexpr = "freqTable(array(2,4,6,8,10,12,12,4,8,8,8,2))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Map<String,Number>> out = (List<Map<String, Number>>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
Map<String, Number> bucket = out.get(0);
|
||||
assertEquals(bucket.get("value").longValue(), 2);
|
||||
assertEquals(bucket.get("count").longValue(), 2);
|
||||
|
||||
bucket = out.get(1);
|
||||
assertEquals(bucket.get("value").longValue(), 4);
|
||||
assertEquals(bucket.get("count").longValue(), 2);
|
||||
|
||||
bucket = out.get(2);
|
||||
assertEquals(bucket.get("value").longValue(), 6);
|
||||
assertEquals(bucket.get("count").longValue(), 1);
|
||||
|
||||
bucket = out.get(3);
|
||||
assertEquals(bucket.get("value").longValue(), 8);
|
||||
assertEquals(bucket.get("count").longValue(), 4);
|
||||
|
||||
bucket = out.get(4);
|
||||
assertEquals(bucket.get("value").longValue(), 10);
|
||||
assertEquals(bucket.get("count").longValue(), 1);
|
||||
|
||||
bucket = out.get(5);
|
||||
assertEquals(bucket.get("value").longValue(), 12);
|
||||
assertEquals(bucket.get("count").longValue(), 2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testCosineSimilarity() throws Exception {
|
||||
String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))";
|
||||
|
@ -6171,7 +6213,109 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assertTrue(cs.doubleValue() == 0.9838197164968291);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPoissonDistribution() throws Exception {
|
||||
String cexpr = "let(a=poissonDistribution(100)," +
|
||||
" b=sample(a, 10000)," +
|
||||
" tuple(d=describe(b), " +
|
||||
" p=probability(a, 100), " +
|
||||
" c=cumulativeProbability(a, 100)))";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Map map = (Map)tuples.get(0).get("d");
|
||||
Number mean = (Number)map.get("mean");
|
||||
Number var = (Number)map.get("var");
|
||||
//The mean and variance should be almost the same for poisson distribution
|
||||
assertEquals(mean.doubleValue(), var.doubleValue(), 3.0);
|
||||
Number prob = (Number)tuples.get(0).get("p");
|
||||
assertEquals(prob.doubleValue(), 0.03986099680914713, 0.0);
|
||||
Number cprob = (Number)tuples.get(0).get("c");
|
||||
assertEquals(cprob.doubleValue(), 0.5265621985303708, 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBinomialDistribution() throws Exception {
|
||||
String cexpr = "let(a=binomialDistribution(100, .50)," +
|
||||
" b=sample(a, 10000)," +
|
||||
" tuple(d=describe(b), " +
|
||||
" p=probability(a, 50), " +
|
||||
" c=cumulativeProbability(a, 50)))";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Number prob = (Number)tuples.get(0).get("p");
|
||||
assertEquals(prob.doubleValue(),0.07958923738717877, 0.0);
|
||||
Number cprob = (Number)tuples.get(0).get("c");
|
||||
assertEquals(cprob.doubleValue(), 0.5397946186935851, 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUniformIntegerDistribution() throws Exception {
|
||||
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
|
||||
" b=sample(a, 10000)," +
|
||||
" tuple(d=describe(b), " +
|
||||
" p=probability(a, 5), " +
|
||||
" c=cumulativeProbability(a, 5)))";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Map map = (Map)tuples.get(0).get("d");
|
||||
Number N = (Number)map.get("N");
|
||||
assertEquals(N.intValue(), 10000);
|
||||
Number prob = (Number)tuples.get(0).get("p");
|
||||
assertEquals(prob.doubleValue(), 0.1, 0.0);
|
||||
Number cprob = (Number)tuples.get(0).get("c");
|
||||
assertEquals(cprob.doubleValue(), 0.5, 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEnumeratedDistribution() throws Exception {
|
||||
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
|
||||
" b=sample(a, 10000)," +
|
||||
" c=enumeratedDistribution(b),"+
|
||||
" tuple(d=describe(b), " +
|
||||
" p=probability(c, 5), " +
|
||||
" c=cumulativeProbability(c, 5)))";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Map map = (Map)tuples.get(0).get("d");
|
||||
Number N = (Number)map.get("N");
|
||||
assertEquals(N.intValue(), 10000);
|
||||
Number prob = (Number)tuples.get(0).get("p");
|
||||
assertEquals(prob.doubleValue(), 0.1, 0.07);
|
||||
Number cprob = (Number)tuples.get(0).get("c");
|
||||
assertEquals(cprob.doubleValue(), 0.5, 0.07);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDotProduct() throws Exception {
|
||||
|
|
Loading…
Reference in New Issue