mirror of https://github.com/apache/lucene.git
SOLR-11422: Add probabilities parameter to the enumeratedDistribution Stream Evaluator
This commit is contained in:
parent
71eb59e043
commit
c322e36926
|
@ -25,7 +25,7 @@ import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
|
|||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
|
||||
public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
|
||||
|
||||
private static final long serialVersionUID = 1;
|
||||
|
||||
|
@ -34,12 +34,21 @@ public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator i
|
|||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first) throws IOException{
|
||||
if(null == first){
|
||||
public Object doWork(Object... values) throws IOException{
|
||||
if(values.length == 0){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
int[] samples = ((List)first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
|
||||
return new EnumeratedIntegerDistribution(samples);
|
||||
if(values.length == 1) {
|
||||
List<Number> first = (List<Number>)values[0];
|
||||
int[] samples = ((List) first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
|
||||
return new EnumeratedIntegerDistribution(samples);
|
||||
} else {
|
||||
List<Number> first = (List<Number>)values[0];
|
||||
List<Number> second = (List<Number>)values[1];
|
||||
int[] singletons = ((List) first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
|
||||
double[] probs = ((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray();
|
||||
return new EnumeratedIntegerDistribution(singletons, probs);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6480,6 +6480,26 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assertEquals(prob.doubleValue(), 0.1, 0.07);
|
||||
Number cprob = (Number)tuples.get(0).get("c");
|
||||
assertEquals(cprob.doubleValue(), 0.5, 0.07);
|
||||
|
||||
|
||||
cexpr = "let(a=sample(enumeratedDistribution(array(1,2,3,4), array(40, 30, 20, 10)), 50000),"+
|
||||
"b=freqTable(a),"+
|
||||
"y=col(b, pct))";
|
||||
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> freqs = (List<Number>)tuples.get(0).get("y");
|
||||
assertEquals(freqs.get(0).doubleValue(), .40, .03);
|
||||
assertEquals(freqs.get(1).doubleValue(), .30, .03);
|
||||
assertEquals(freqs.get(2).doubleValue(), .20, .03);
|
||||
assertEquals(freqs.get(3).doubleValue(), .10, .03);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue