SOLR-11422: Add probabilities parameter to the enumeratedDistribution Stream Evaluator

This commit is contained in:
Joel Bernstein 2017-09-29 13:10:26 -04:00
parent 71eb59e043
commit c322e36926
2 changed files with 34 additions and 5 deletions

View File

@ -25,7 +25,7 @@ import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
private static final long serialVersionUID = 1;
@ -34,12 +34,21 @@ public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator i
}
@Override
public Object doWork(Object first) throws IOException{
if(null == first){
public Object doWork(Object... values) throws IOException{
if(values.length == 0){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
}
if(values.length == 1) {
List<Number> first = (List<Number>)values[0];
int[] samples = ((List) first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
return new EnumeratedIntegerDistribution(samples);
} else {
List<Number> first = (List<Number>)values[0];
List<Number> second = (List<Number>)values[1];
int[] singletons = ((List) first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
double[] probs = ((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray();
return new EnumeratedIntegerDistribution(singletons, probs);
}
}
}

View File

@ -6480,6 +6480,26 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertEquals(prob.doubleValue(), 0.1, 0.07);
Number cprob = (Number)tuples.get(0).get("c");
assertEquals(cprob.doubleValue(), 0.5, 0.07);
cexpr = "let(a=sample(enumeratedDistribution(array(1,2,3,4), array(40, 30, 20, 10)), 50000),"+
"b=freqTable(a),"+
"y=col(b, pct))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
solrStream = new SolrStream(url, paramsLoc);
context = new StreamContext();
solrStream.setStreamContext(context);
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
List<Number> freqs = (List<Number>)tuples.get(0).get("y");
assertEquals(freqs.get(0).doubleValue(), .40, .03);
assertEquals(freqs.get(1).doubleValue(), .30, .03);
assertEquals(freqs.get(2).doubleValue(), .20, .03);
assertEquals(freqs.get(3).doubleValue(), .10, .03);
}
@Test