SOLR-12221: Add valueAt Stream Evaluator

This commit is contained in:
Joel Bernstein 2018-04-13 13:31:30 -04:00
parent 8462b134ea
commit 487daab629
4 changed files with 102 additions and 2 deletions

View File

@ -233,6 +233,7 @@ public class Lang {
.withFunctionName("linfnorm", LInfNormEvaluator.class) .withFunctionName("linfnorm", LInfNormEvaluator.class)
.withFunctionName("matrixMult", MatrixMultiplyEvaluator.class) .withFunctionName("matrixMult", MatrixMultiplyEvaluator.class)
.withFunctionName("bicubicSpline", BicubicSplineEvaluator.class) .withFunctionName("bicubicSpline", BicubicSplineEvaluator.class)
.withFunctionName("valueAt", ValueAtEvaluator.class)
// Boolean Stream Evaluators // Boolean Stream Evaluators

View File

@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class ValueAtEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
public ValueAtEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
super(expression, factory);
}
@Override
public Object doWork(Object... values) throws IOException {
if(values[0] instanceof List) {
List<Number> c = (List<Number>) values[0];
int index = -1;
if(values.length == 2) {
index = ((Number)values[1]).intValue();
if(index >= c.size()) {
throw new IOException("Index out of bounds: "+index);
}
} else {
throw new IOException("The valueAt function expects an array and array index as parameters.");
}
return c.get(index);
} else if(values[0] instanceof Matrix) {
Matrix c = (Matrix) values[0];
double[][] data = c.getData();
int row = -1;
int col = -1;
if(values.length == 3) {
row = ((Number)values[1]).intValue();
if(row >= data.length) {
throw new IOException("Row index out of bounds: "+row);
}
col = ((Number)values[2]).intValue();
if(col >= data[0].length) {
throw new IOException("Column index out of bounds: "+col);
}
} else {
throw new IOException("The valueAt function expects a matrix and row and column indexes");
}
return data[row][col];
} else {
throw new IOException("The valueAt function expects a numeric array or matrix as the first parameter");
}
}
}

View File

@ -68,7 +68,7 @@ public class TestLang extends LuceneTestCase {
TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME, TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME,
TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow", TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
"mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt", "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
"cbrt", "coalesce", "uuid", "if", "convert"}; "cbrt", "coalesce", "uuid", "if", "convert", "valueAt"};
@Test @Test
public void testLang() { public void testLang() {
@ -85,7 +85,7 @@ public class TestLang extends LuceneTestCase {
assertTrue("Testing function:"+func, registeredFunctions.containsKey(func)); assertTrue("Testing function:"+func, registeredFunctions.containsKey(func));
} }
//Check that ech function that is registered is expected. //Check that each function that is registered is expected.
Set<String> keys = registeredFunctions.keySet(); Set<String> keys = registeredFunctions.keySet();
for(String key : keys) { for(String key : keys) {
assertTrue("Testing key:"+key, functions.contains(key)); assertTrue("Testing key:"+key, functions.contains(key));

View File

@ -2354,6 +2354,31 @@ public class MathExpressionTest extends SolrCloudTestCase {
} }
} }
@Test
public void testValueAt() throws Exception {
String cexpr = "let(echo=true, " +
" b=array(1,2,3,4), " +
" c=matrix(array(5,6,7), " +
" array(8,9,10)), " +
" d=valueAt(b, 3)," +
" e=valueAt(c, 1, 0))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Number value1 = (Number)tuples.get(0).get("d");
Number value2 = (Number)tuples.get(0).get("e");
assertEquals(value1.intValue(), 4);
assertEquals(value2.intValue(), 8);
}
@Test @Test
public void testBetaDistribution() throws Exception { public void testBetaDistribution() throws Exception {
String cexpr = "let(a=sample(betaDistribution(1, 5), 50000), b=hist(a, 11), c=col(b, N))"; String cexpr = "let(a=sample(betaDistribution(1, 5), 50000), b=hist(a, 11), c=col(b, N))";