mirror of https://github.com/apache/lucene.git
SOLR-11241: Add discrete counting and probability Stream Evaluators
This commit is contained in:
parent
1fe611d094
commit
4496612120
|
@ -94,7 +94,7 @@ org.apache.calcite.version = 1.13.0
|
||||||
/org.apache.commons/commons-compress = 1.11
|
/org.apache.commons/commons-compress = 1.11
|
||||||
/org.apache.commons/commons-exec = 1.3
|
/org.apache.commons/commons-exec = 1.3
|
||||||
/org.apache.commons/commons-lang3 = 3.6
|
/org.apache.commons/commons-lang3 = 3.6
|
||||||
/org.apache.commons/commons-math3 = 3.4.1
|
/org.apache.commons/commons-math3 = 3.6.1
|
||||||
|
|
||||||
org.apache.curator.version = 2.8.0
|
org.apache.curator.version = 2.8.0
|
||||||
/org.apache.curator/curator-client = ${org.apache.curator.version}
|
/org.apache.curator/curator-client = ${org.apache.curator.version}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.PoissonDistribution;
|
||||||
import org.apache.solr.client.solrj.io.ModelCache;
|
import org.apache.solr.client.solrj.io.ModelCache;
|
||||||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||||
import org.apache.solr.client.solrj.io.Tuple;
|
import org.apache.solr.client.solrj.io.Tuple;
|
||||||
|
@ -249,6 +250,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
||||||
.withFunctionName("copyOfRange", CopyOfRangeEvaluator.class)
|
.withFunctionName("copyOfRange", CopyOfRangeEvaluator.class)
|
||||||
.withFunctionName("copyOf", CopyOfEvaluator.class)
|
.withFunctionName("copyOf", CopyOfEvaluator.class)
|
||||||
.withFunctionName("cov", CovarianceEvaluator.class)
|
.withFunctionName("cov", CovarianceEvaluator.class)
|
||||||
|
.withFunctionName("corr", CorrelationEvaluator.class)
|
||||||
.withFunctionName("describe", DescribeEvaluator.class)
|
.withFunctionName("describe", DescribeEvaluator.class)
|
||||||
.withFunctionName("distance", EuclideanDistanceEvaluator.class)
|
.withFunctionName("distance", EuclideanDistanceEvaluator.class)
|
||||||
.withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class)
|
.withFunctionName("empiricalDistribution", EmpiricalDistributionEvaluator.class)
|
||||||
|
@ -281,9 +283,18 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
||||||
.withFunctionName("ebeDivide", EBEDivideEvaluator.class)
|
.withFunctionName("ebeDivide", EBEDivideEvaluator.class)
|
||||||
.withFunctionName("dotProduct", DotProductEvaluator.class)
|
.withFunctionName("dotProduct", DotProductEvaluator.class)
|
||||||
.withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
|
.withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
|
||||||
|
.withFunctionName("freqTable", FrequencyTableEvaluator.class)
|
||||||
|
.withFunctionName("uniformIntegerDistribution", UniformIntegerDistributionEvaluator.class)
|
||||||
|
.withFunctionName("binomialDistribution", BinomialDistributionEvaluator.class)
|
||||||
|
.withFunctionName("poissonDistribution", PoissonDistributionEvaluator.class)
|
||||||
|
.withFunctionName("enumeratedDistribution", EnumeratedDistributionEvaluator.class)
|
||||||
|
.withFunctionName("probability", ProbabilityEvaluator.class)
|
||||||
|
|
||||||
// Boolean Stream Evaluators
|
// Boolean Stream Evaluators
|
||||||
.withFunctionName("and", AndEvaluator.class)
|
|
||||||
|
|
||||||
|
|
||||||
|
.withFunctionName("and", AndEvaluator.class)
|
||||||
.withFunctionName("eor", ExclusiveOrEvaluator.class)
|
.withFunctionName("eor", ExclusiveOrEvaluator.class)
|
||||||
.withFunctionName("eq", EqualToEvaluator.class)
|
.withFunctionName("eq", EqualToEvaluator.class)
|
||||||
.withFunctionName("gt", GreaterThanEvaluator.class)
|
.withFunctionName("gt", GreaterThanEvaluator.class)
|
||||||
|
@ -331,7 +342,6 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
||||||
.withFunctionName("cbrt", CubedRootEvaluator.class)
|
.withFunctionName("cbrt", CubedRootEvaluator.class)
|
||||||
.withFunctionName("coalesce", CoalesceEvaluator.class)
|
.withFunctionName("coalesce", CoalesceEvaluator.class)
|
||||||
.withFunctionName("uuid", UuidEvaluator.class)
|
.withFunctionName("uuid", UuidEvaluator.class)
|
||||||
.withFunctionName("corr", CorrelationEvaluator.class)
|
|
||||||
|
|
||||||
// Conditional Stream Evaluators
|
// Conditional Stream Evaluators
|
||||||
.withFunctionName("if", IfThenElseEvaluator.class)
|
.withFunctionName("if", IfThenElseEvaluator.class)
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.BinomialDistribution;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class BinomialDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1;
|
||||||
|
|
||||||
|
public BinomialDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Number numberOfTrials = (Number)first;
|
||||||
|
Number successProb = (Number)second;
|
||||||
|
|
||||||
|
return new BinomialDistribution(numberOfTrials.intValue(), successProb.doubleValue());
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.RealDistribution;
|
import org.apache.commons.math3.distribution.RealDistribution;
|
||||||
|
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
@ -38,17 +39,21 @@ public class CumulativeProbabilityEvaluator extends RecursiveObjectEvaluator imp
|
||||||
if(null == second){
|
if(null == second){
|
||||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
}
|
}
|
||||||
if(!(first instanceof RealDistribution)){
|
if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
|
||||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a real or integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
}
|
}
|
||||||
if(!(second instanceof Number)){
|
if(!(second instanceof Number)){
|
||||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
}
|
}
|
||||||
|
|
||||||
RealDistribution rd = (RealDistribution)first;
|
if(first instanceof RealDistribution) {
|
||||||
Number predictOver = (Number)second;
|
RealDistribution rd = (RealDistribution) first;
|
||||||
|
Number predictOver = (Number) second;
|
||||||
return rd.cumulativeProbability(predictOver.doubleValue());
|
return rd.cumulativeProbability(predictOver.doubleValue());
|
||||||
|
} else {
|
||||||
|
IntegerDistribution id = (IntegerDistribution) first;
|
||||||
|
Number predictOver = (Number) second;
|
||||||
|
return id.cumulativeProbability(predictOver.intValue());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1;
|
||||||
|
|
||||||
|
public EnumeratedDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
|
||||||
|
int[] samples = ((List)first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
|
||||||
|
return new EnumeratedIntegerDistribution(samples);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,86 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.EmpiricalDistribution;
|
||||||
|
import org.apache.commons.math3.stat.Frequency;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||||
|
import org.apache.solr.client.solrj.io.Tuple;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class FrequencyTableEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public FrequencyTableEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
|
||||||
|
if(containedEvaluators.size() < 1){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",expression,containedEvaluators.size()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object... values) throws IOException {
|
||||||
|
if(Arrays.stream(values).anyMatch(item -> null == item)){
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<?> sourceValues;
|
||||||
|
|
||||||
|
if(values.length == 1){
|
||||||
|
sourceValues = values[0] instanceof List<?> ? (List<?>)values[0] : Arrays.asList(values[0]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting at least one value but found %d",toExpression(constructingFactory),containedEvaluators.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Frequency frequency = new Frequency();
|
||||||
|
|
||||||
|
for(Object o : sourceValues) {
|
||||||
|
Number number = (Number)o;
|
||||||
|
frequency.addValue(number.longValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Tuple> histogramBins = new ArrayList<>();
|
||||||
|
|
||||||
|
Iterator iterator = frequency.valuesIterator();
|
||||||
|
|
||||||
|
while(iterator.hasNext()){
|
||||||
|
Long value = (Long)iterator.next();
|
||||||
|
Map<String,Number> map = new HashMap<>();
|
||||||
|
map.put("value", value.longValue());
|
||||||
|
map.put("count", frequency.getCount(value));
|
||||||
|
map.put("cumFreq", frequency.getCumFreq(value));
|
||||||
|
map.put("cumPct", frequency.getCumPct(value));
|
||||||
|
map.put("pct", frequency.getPct(value));
|
||||||
|
histogramBins.add(new Tuple(map));
|
||||||
|
}
|
||||||
|
return histogramBins;
|
||||||
|
}
|
||||||
|
}
|
|
@ -26,6 +26,7 @@ import java.util.Map;
|
||||||
|
|
||||||
import org.apache.commons.math3.random.EmpiricalDistribution;
|
import org.apache.commons.math3.random.EmpiricalDistribution;
|
||||||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
|
||||||
|
import org.apache.solr.client.solrj.io.Tuple;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
@ -68,7 +69,7 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
|
||||||
EmpiricalDistribution distribution = new EmpiricalDistribution(bins);
|
EmpiricalDistribution distribution = new EmpiricalDistribution(bins);
|
||||||
distribution.load(((List<?>)sourceValues).stream().mapToDouble(value -> ((Number)value).doubleValue()).toArray());;
|
distribution.load(((List<?>)sourceValues).stream().mapToDouble(value -> ((Number)value).doubleValue()).toArray());;
|
||||||
|
|
||||||
List<Map<String,Number>> histogramBins = new ArrayList<>();
|
List<Tuple> histogramBins = new ArrayList<>();
|
||||||
for(SummaryStatistics binSummary : distribution.getBinStats()) {
|
for(SummaryStatistics binSummary : distribution.getBinStats()) {
|
||||||
Map<String,Number> map = new HashMap<>();
|
Map<String,Number> map = new HashMap<>();
|
||||||
map.put("max", binSummary.getMax());
|
map.put("max", binSummary.getMax());
|
||||||
|
@ -78,7 +79,9 @@ public class HistogramEvaluator extends RecursiveNumericEvaluator implements Man
|
||||||
map.put("sum", binSummary.getSum());
|
map.put("sum", binSummary.getSum());
|
||||||
map.put("N", binSummary.getN());
|
map.put("N", binSummary.getN());
|
||||||
map.put("var", binSummary.getVariance());
|
map.put("var", binSummary.getVariance());
|
||||||
histogramBins.add(map);
|
map.put("cumProb", distribution.cumulativeProbability(binSummary.getMean()));
|
||||||
|
map.put("prob", distribution.probability(binSummary.getMin(), binSummary.getMax()));
|
||||||
|
histogramBins.add(new Tuple(map));
|
||||||
}
|
}
|
||||||
|
|
||||||
return histogramBins;
|
return histogramBins;
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.PoissonDistribution;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class PoissonDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1;
|
||||||
|
|
||||||
|
public PoissonDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Number mean = (Number)first;
|
||||||
|
|
||||||
|
return new PoissonDistribution(mean.intValue());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public ProbabilityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(!(first instanceof IntegerDistribution)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a IntegerDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
if(!(second instanceof Number)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
IntegerDistribution d = (IntegerDistribution) first;
|
||||||
|
Number predictOver = (Number) second;
|
||||||
|
return d.probability(predictOver.intValue());
|
||||||
|
}
|
||||||
|
}
|
|
@ -22,6 +22,7 @@ import java.util.Arrays;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||||
import org.apache.commons.math3.distribution.RealDistribution;
|
import org.apache.commons.math3.distribution.RealDistribution;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
@ -42,16 +43,19 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements TwoValu
|
||||||
if(null == second){
|
if(null == second){
|
||||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
}
|
}
|
||||||
if(!(first instanceof RealDistribution)){
|
if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)){
|
||||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a RealDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a Real or Integer Distribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
}
|
}
|
||||||
if(!(second instanceof Number)){
|
if(!(second instanceof Number)){
|
||||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
}
|
}
|
||||||
|
|
||||||
RealDistribution realDistribution = (RealDistribution)first;
|
|
||||||
|
|
||||||
return Arrays.stream(realDistribution.sample(((Number)second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if(first instanceof RealDistribution) {
|
||||||
|
RealDistribution realDistribution = (RealDistribution) first;
|
||||||
|
return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
|
||||||
|
} else {
|
||||||
|
IntegerDistribution integerDistribution = (IntegerDistribution) first;
|
||||||
|
return Arrays.stream(integerDistribution.sample(((Number) second).intValue())).mapToObj(item -> item).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class UniformIntegerDistributionEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1;
|
||||||
|
|
||||||
|
public UniformIntegerDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Number lower = (Number)first;
|
||||||
|
Number upper = (Number)second;
|
||||||
|
|
||||||
|
return new UniformIntegerDistribution(lower.intValue(), upper.intValue());
|
||||||
|
}
|
||||||
|
}
|
|
@ -6155,6 +6155,48 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
||||||
assertTrue(out.get(5).intValue() == 2);
|
assertTrue(out.get(5).intValue() == 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFreqTable() throws Exception {
|
||||||
|
String cexpr = "freqTable(array(2,4,6,8,10,12,12,4,8,8,8,2))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
List<Map<String,Number>> out = (List<Map<String, Number>>)tuples.get(0).get("return-value");
|
||||||
|
assertTrue(out.size() == 6);
|
||||||
|
Map<String, Number> bucket = out.get(0);
|
||||||
|
assertEquals(bucket.get("value").longValue(), 2);
|
||||||
|
assertEquals(bucket.get("count").longValue(), 2);
|
||||||
|
|
||||||
|
bucket = out.get(1);
|
||||||
|
assertEquals(bucket.get("value").longValue(), 4);
|
||||||
|
assertEquals(bucket.get("count").longValue(), 2);
|
||||||
|
|
||||||
|
bucket = out.get(2);
|
||||||
|
assertEquals(bucket.get("value").longValue(), 6);
|
||||||
|
assertEquals(bucket.get("count").longValue(), 1);
|
||||||
|
|
||||||
|
bucket = out.get(3);
|
||||||
|
assertEquals(bucket.get("value").longValue(), 8);
|
||||||
|
assertEquals(bucket.get("count").longValue(), 4);
|
||||||
|
|
||||||
|
bucket = out.get(4);
|
||||||
|
assertEquals(bucket.get("value").longValue(), 10);
|
||||||
|
assertEquals(bucket.get("count").longValue(), 1);
|
||||||
|
|
||||||
|
bucket = out.get(5);
|
||||||
|
assertEquals(bucket.get("value").longValue(), 12);
|
||||||
|
assertEquals(bucket.get("count").longValue(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCosineSimilarity() throws Exception {
|
public void testCosineSimilarity() throws Exception {
|
||||||
String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))";
|
String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))";
|
||||||
|
@ -6171,7 +6213,109 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
||||||
assertTrue(cs.doubleValue() == 0.9838197164968291);
|
assertTrue(cs.doubleValue() == 0.9838197164968291);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPoissonDistribution() throws Exception {
|
||||||
|
String cexpr = "let(a=poissonDistribution(100)," +
|
||||||
|
" b=sample(a, 10000)," +
|
||||||
|
" tuple(d=describe(b), " +
|
||||||
|
" p=probability(a, 100), " +
|
||||||
|
" c=cumulativeProbability(a, 100)))";
|
||||||
|
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
Map map = (Map)tuples.get(0).get("d");
|
||||||
|
Number mean = (Number)map.get("mean");
|
||||||
|
Number var = (Number)map.get("var");
|
||||||
|
//The mean and variance should be almost the same for poisson distribution
|
||||||
|
assertEquals(mean.doubleValue(), var.doubleValue(), 3.0);
|
||||||
|
Number prob = (Number)tuples.get(0).get("p");
|
||||||
|
assertEquals(prob.doubleValue(), 0.03986099680914713, 0.0);
|
||||||
|
Number cprob = (Number)tuples.get(0).get("c");
|
||||||
|
assertEquals(cprob.doubleValue(), 0.5265621985303708, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBinomialDistribution() throws Exception {
|
||||||
|
String cexpr = "let(a=binomialDistribution(100, .50)," +
|
||||||
|
" b=sample(a, 10000)," +
|
||||||
|
" tuple(d=describe(b), " +
|
||||||
|
" p=probability(a, 50), " +
|
||||||
|
" c=cumulativeProbability(a, 50)))";
|
||||||
|
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
Number prob = (Number)tuples.get(0).get("p");
|
||||||
|
assertEquals(prob.doubleValue(),0.07958923738717877, 0.0);
|
||||||
|
Number cprob = (Number)tuples.get(0).get("c");
|
||||||
|
assertEquals(cprob.doubleValue(), 0.5397946186935851, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUniformIntegerDistribution() throws Exception {
|
||||||
|
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
|
||||||
|
" b=sample(a, 10000)," +
|
||||||
|
" tuple(d=describe(b), " +
|
||||||
|
" p=probability(a, 5), " +
|
||||||
|
" c=cumulativeProbability(a, 5)))";
|
||||||
|
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
Map map = (Map)tuples.get(0).get("d");
|
||||||
|
Number N = (Number)map.get("N");
|
||||||
|
assertEquals(N.intValue(), 10000);
|
||||||
|
Number prob = (Number)tuples.get(0).get("p");
|
||||||
|
assertEquals(prob.doubleValue(), 0.1, 0.0);
|
||||||
|
Number cprob = (Number)tuples.get(0).get("c");
|
||||||
|
assertEquals(cprob.doubleValue(), 0.5, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEnumeratedDistribution() throws Exception {
|
||||||
|
String cexpr = "let(a=uniformIntegerDistribution(1, 10)," +
|
||||||
|
" b=sample(a, 10000)," +
|
||||||
|
" c=enumeratedDistribution(b),"+
|
||||||
|
" tuple(d=describe(b), " +
|
||||||
|
" p=probability(c, 5), " +
|
||||||
|
" c=cumulativeProbability(c, 5)))";
|
||||||
|
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
Map map = (Map)tuples.get(0).get("d");
|
||||||
|
Number N = (Number)map.get("N");
|
||||||
|
assertEquals(N.intValue(), 10000);
|
||||||
|
Number prob = (Number)tuples.get(0).get("p");
|
||||||
|
assertEquals(prob.doubleValue(), 0.1, 0.07);
|
||||||
|
Number cprob = (Number)tuples.get(0).get("c");
|
||||||
|
assertEquals(cprob.doubleValue(), 0.5, 0.07);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDotProduct() throws Exception {
|
public void testDotProduct() throws Exception {
|
||||||
|
|
Loading…
Reference in New Issue