diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index a4ac55934c3..7441fd3efd2 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -172,6 +172,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("regress", RegressionEvaluator.class) .withFunctionName("cov", CovarianceEvaluator.class) .withFunctionName("conv", ConvolutionEvaluator.class) + .withFunctionName("normalize", NormalizeEvaluator.class) // metrics .withFunctionName("min", MinMetric.class) diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/NormalizeEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/NormalizeEvaluator.java new file mode 100644 index 00000000000..e011933ad69 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/NormalizeEvaluator.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.stream; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.math3.stat.StatUtils; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.eval.ComplexEvaluator; +import org.apache.solr.client.solrj.io.eval.StreamEvaluator; +import org.apache.solr.client.solrj.io.stream.expr.Explanation; +import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType; +import org.apache.solr.client.solrj.io.stream.expr.Expressible; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class NormalizeEvaluator extends ComplexEvaluator implements Expressible { + + private static final long serialVersionUID = 1; + + public NormalizeEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + public List evaluate(Tuple tuple) throws IOException { + StreamEvaluator colEval1 = subEvaluators.get(0); + + List numbers1 = (List)colEval1.evaluate(tuple); + double[] column1 = new double[numbers1.size()]; + + for(int i=0; i normalizeList = new ArrayList(normalized.length); + for(double d : normalized) { + normalizeList.add(d); + } + + return normalizeList; + } + + @Override + public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { + StreamExpression expression = new StreamExpression(factory.getFunctionName(getClass())); + return expression; + } + + @Override + public Explanation toExplanation(StreamFactory factory) throws IOException { + return new Explanation(nodeId.toString()) + .withExpressionType(ExpressionType.EVALUATOR) + .withFunctionName(factory.getFunctionName(getClass())) + .withImplementingClass(getClass().getName()) + .withExpression(toExpression(factory).toString()); + } +} \ No newline at end of file diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index f674236a400..25f09258372 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -570,7 +570,7 @@ public class StreamExpressionTest extends SolrCloudTestCase { .add(id, "0", "a_i", "1", "a_f", "0", "s_multi", "aaa", "s_multi", "bbb", "i_multi", "100", "i_multi", "200") .add(id, "2", "a_s", "hello2", "a_i", "3", "a_f", "0") .add(id, "3", "a_s", "hello3", "a_i", "4", "a_f", "3") - .add(id, "4", "a_s", "hello4", "a_f", "4") + .add(id, "4", "a_s", "hello4", "a_f", "4") .add(id, "1", "a_s", "hello1", "a_i", "2", "a_f", "1") .commit(cluster.getSolrClient(), COLLECTIONORALIAS); @@ -4049,7 +4049,7 @@ public class StreamExpressionTest extends SolrCloudTestCase { false, true, TIMEOUT); new UpdateRequest() - .add(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "0", "s_multi", "aaaa", "s_multi", "bbbb", "i_multi", "4", "i_multi", "7") + .add(id, "0", "a_s", "hello0", "a_i", "0", "a_f", "0", "s_multi", "aaaa", "s_multi", "bbbb", "i_multi", "4", "i_multi", "7") .add(id, "2", "a_s", "hello2", "a_i", "2", "a_f", "0", "s_multi", "aaaa1", "s_multi", "bbbb1", "i_multi", "44", "i_multi", "77") .add(id, "3", "a_s", "hello3", "a_i", "3", "a_f", "3", "s_multi", "aaaa2", "s_multi", "bbbb2", "i_multi", "444", "i_multi", "777") .add(id, "4", "a_s", "hello4", "a_i", "4", "a_f", "4", "s_multi", "aaaa3", "s_multi", "bbbb3", "i_multi", "4444", "i_multi", "7777") @@ -5401,6 +5401,59 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertTrue(prediction == 600.0D); } + @Test + public void testNormalize() throws Exception { + UpdateRequest updateRequest = new UpdateRequest(); + + updateRequest.add(id, "1", "price_f", "100.0", "col_s", "a", "order_i", "1"); + updateRequest.add(id, "2", "price_f", "200.0", "col_s", "a", "order_i", "2"); + updateRequest.add(id, "3", "price_f", "300.0", "col_s", "a", "order_i", "3"); + updateRequest.add(id, "4", "price_f", "100.0", "col_s", "a", "order_i", "4"); + updateRequest.add(id, "5", "price_f", "200.0", "col_s", "a", "order_i", "5"); + updateRequest.add(id, "6", "price_f", "400.0", "col_s", "a", "order_i", "6"); + updateRequest.add(id, "7", "price_f", "600.0", "col_s", "a", "order_i", "7"); + + updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS); + + String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")"; + String cexpr = "let(a="+expr1+", c=col(a, price_f), tuple(n=normalize(c), c=c))"; + + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Tuple tuple = tuples.get(0); + List col = (List)tuple.get("c"); + List normalized = (List)tuple.get("n"); + + assertTrue(col.size() == normalized.size()); + + double total = 0.0D; + + for(double d : normalized) { + total += d; + } + + double mean = total/normalized.size(); + assert(Math.round(mean) == 0); + + double sd = 0; + for (int i = 0; i < normalized.size(); i++) + { + sd += Math.pow(normalized.get(i) - mean, 2) / normalized.size(); + } + double standardDeviation = Math.sqrt(sd); + + assertTrue(Math.round(standardDeviation) == 1); + } + @Test public void testListStream() throws Exception { UpdateRequest updateRequest = new UpdateRequest(); @@ -5717,7 +5770,7 @@ public class StreamExpressionTest extends SolrCloudTestCase { cluster.getSolrClient().commit("destination"); paramsLoc = new ModifiableSolrParams(); paramsLoc.set("expr", "search(destination, q=\"*:*\", fl=\"id, body_t, field_i\", rows=1000, sort=\"field_i asc\")"); - paramsLoc.set("qt","/stream"); + paramsLoc.set("qt", "/stream"); SolrStream solrStream = new SolrStream(url, paramsLoc); List tuples = getTuples(solrStream);