diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index 949a040bade..67d6f7f7715 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -284,6 +284,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("ttest", TTestEvaluator.class) .withFunctionName("pairedTtest", PairedTTestEvaluator.class) .withFunctionName("multiVariateNormalDistribution", MultiVariateNormalDistributionEvaluator.class) + .withFunctionName("integrate", IntegrateEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/IntegrateEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/IntegrateEvaluator.java new file mode 100644 index 00000000000..277748c8dd4 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/IntegrateEvaluator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.integration.RombergIntegrator; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class IntegrateEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { + protected static final long serialVersionUID = 1L; + + public IntegrateEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + } + + @Override + public Object doWork(Object... values) throws IOException { + + if(values.length != 3) { + throw new IOException("The integrate function requires 3 parameters"); + } + + if (!(values[0] instanceof VectorFunction)) { + throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a FunctionVector", toExpression(constructingFactory), values[0].getClass().getSimpleName())); + } + + VectorFunction vectorFunction = (VectorFunction) values[0]; + if(!(vectorFunction.getFunction() instanceof UnivariateFunction)) { + throw new IOException("Cannot evaluate integral from parameter."); + } + + Number min = null; + Number max = null; + + if(values[1] instanceof Number) { + min = (Number) values[1]; + } else { + throw new IOException("The second parameter of the integrate function must be a number"); + } + + if(values[2] instanceof Number ) { + max = (Number) values[2]; + } else { + throw new IOException("The third parameter of the integrate function must be a number"); + } + + UnivariateFunction func = (UnivariateFunction)vectorFunction.getFunction(); + + RombergIntegrator rombergIntegrator = new RombergIntegrator(); + return rombergIntegrator.integrate(5000, func, min.doubleValue(), max.doubleValue()); + } +} diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index 01ba1ddafd7..a9f0f66461e 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -7227,6 +7227,32 @@ public class StreamExpressionTest extends SolrCloudTestCase { } + @Test + public void testIntegrate() throws Exception { + String cexpr = "let(echo=true, " + + "a=sequence(50, 1, 0), " + + "b=spline(a), " + + "c=integrate(b, 0, 49), " + + "d=integrate(b, 0, 20), " + + "e=integrate(b, 20, 49))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Number integral = (Number)tuples.get(0).get("c"); + assertEquals(integral.doubleValue(), 49, 0.0); + integral = (Number)tuples.get(0).get("d"); + assertEquals(integral.doubleValue(), 20, 0.0); + integral = (Number)tuples.get(0).get("e"); + assertEquals(integral.doubleValue(), 29, 0.0); + } + + @Test public void testLoess() throws Exception { String cexpr = "let(echo=true," +