diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java index 08ba211a25b..fdbb875c9ca 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java @@ -233,6 +233,7 @@ public class Lang { .withFunctionName("linfnorm", LInfNormEvaluator.class) .withFunctionName("matrixMult", MatrixMultiplyEvaluator.class) .withFunctionName("bicubicSpline", BicubicSplineEvaluator.class) + .withFunctionName("valueAt", ValueAtEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ValueAtEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ValueAtEvaluator.java new file mode 100644 index 00000000000..6df3709a040 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ValueAtEvaluator.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.List; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class ValueAtEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { + protected static final long serialVersionUID = 1L; + + public ValueAtEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object... values) throws IOException { + if(values[0] instanceof List) { + + List c = (List) values[0]; + int index = -1; + if(values.length == 2) { + index = ((Number)values[1]).intValue(); + if(index >= c.size()) { + throw new IOException("Index out of bounds: "+index); + } + } else { + throw new IOException("The valueAt function expects an array and array index as parameters."); + } + return c.get(index); + + } else if(values[0] instanceof Matrix) { + + Matrix c = (Matrix) values[0]; + double[][] data = c.getData(); + int row = -1; + int col = -1; + if(values.length == 3) { + row = ((Number)values[1]).intValue(); + if(row >= data.length) { + throw new IOException("Row index out of bounds: "+row); + } + + col = ((Number)values[2]).intValue(); + if(col >= data[0].length) { + throw new IOException("Column index out of bounds: "+col); + } + + } else { + throw new IOException("The valueAt function expects a matrix and row and column indexes"); + } + return data[row][col]; + } else { + throw new IOException("The valueAt function expects a numeric array or matrix as the first parameter"); + } + + } +} diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java index 498cb2e6fbb..87f5c4611e2 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java @@ -68,7 +68,7 @@ public class TestLang extends LuceneTestCase { TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME, TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow", "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt", - "cbrt", "coalesce", "uuid", "if", "convert"}; + "cbrt", "coalesce", "uuid", "if", "convert", "valueAt"}; @Test public void testLang() { @@ -85,7 +85,7 @@ public class TestLang extends LuceneTestCase { assertTrue("Testing function:"+func, registeredFunctions.containsKey(func)); } - //Check that ech function that is registered is expected. + //Check that each function that is registered is expected. Set keys = registeredFunctions.keySet(); for(String key : keys) { assertTrue("Testing key:"+key, functions.contains(key)); diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java index abc1c214b79..07570a9dbf5 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java @@ -2354,6 +2354,31 @@ public class MathExpressionTest extends SolrCloudTestCase { } } + + @Test + public void testValueAt() throws Exception { + String cexpr = "let(echo=true, " + + " b=array(1,2,3,4), " + + " c=matrix(array(5,6,7), " + + " array(8,9,10)), " + + " d=valueAt(b, 3)," + + " e=valueAt(c, 1, 0))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Number value1 = (Number)tuples.get(0).get("d"); + Number value2 = (Number)tuples.get(0).get("e"); + assertEquals(value1.intValue(), 4); + assertEquals(value2.intValue(), 8); + } + + @Test public void testBetaDistribution() throws Exception { String cexpr = "let(a=sample(betaDistribution(1, 5), 50000), b=hist(a, 11), c=col(b, N))";