diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java index fdbb875c9ca..067bc8407b1 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java @@ -234,6 +234,7 @@ public class Lang { .withFunctionName("matrixMult", MatrixMultiplyEvaluator.class) .withFunctionName("bicubicSpline", BicubicSplineEvaluator.class) .withFunctionName("valueAt", ValueAtEvaluator.class) + .withFunctionName("memset", MemsetEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MemsetEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MemsetEvaluator.java new file mode 100644 index 00000000000..e8ad9407a86 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MemsetEvaluator.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; + +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.StreamContext; +import org.apache.solr.client.solrj.io.stream.TupleStream; +import org.apache.solr.client.solrj.io.stream.expr.Expressible; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + + +/** + * The MemsetEvaluator reads a TupleStream and copies the values from specific + * fields into arrays that are bound to variable names in a map. The LetStream looks specifically + * for the MemsetEvaluator and makes the variables visible to other functions. + **/ + +public class MemsetEvaluator extends RecursiveEvaluator { + protected static final long serialVersionUID = 1L; + + private TupleStream in; + private String[] cols; + private String[] vars; + private int size = -1; + + public MemsetEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + + /* + * Instantiate and validate all the parameters + */ + + List streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class); + StreamExpressionNamedParameter colsExpression = factory.getNamedOperand(expression, "cols"); + StreamExpressionNamedParameter varsExpression = factory.getNamedOperand(expression, "vars"); + StreamExpressionNamedParameter sizeExpression = factory.getNamedOperand(expression, "size"); + + if(1 != streamExpressions.size()){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting a single stream but found %d",expression, streamExpressions.size())); + } + + if(null == colsExpression || !(colsExpression.getParameter() instanceof StreamExpressionValue)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting single 'cols' parameter listing fields to sort over but didn't find one",expression)); + } + + if(null == varsExpression || !(varsExpression.getParameter() instanceof StreamExpressionValue)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting single 'vars' parameter listing fields to sort over but didn't find one",expression)); + } + + if(null != sizeExpression) { + StreamExpressionValue sizeExpressionValue = (StreamExpressionValue)sizeExpression.getParameter(); + String sizeString = sizeExpressionValue.getValue(); + size = Integer.parseInt(sizeString); + } + + in = factory.constructStream(streamExpressions.get(0)); + + StreamExpressionValue colsExpressionValue = (StreamExpressionValue)colsExpression.getParameter(); + StreamExpressionValue varsExpressionValue = (StreamExpressionValue)varsExpression.getParameter(); + String colsString = colsExpressionValue.getValue(); + String varsString = varsExpressionValue.getValue(); + + vars = varsString.split(","); + cols = colsString.split(","); + + if(cols.length != vars.length) { + throw new IOException("The cols and vars lists must be the same size"); + } + + for(int i=0; i ignoredNamedParameters) throws IOException { + super(expression, factory, ignoredNamedParameters); + } + + public void setStreamContext(StreamContext streamContext) { + this.streamContext = streamContext; + } + + @Override + public Object evaluate(Tuple tuple) throws IOException { + + /* + * Read all the tuples from the underlying stream and + * load specific fields into arrays. Then return + * a map with the variables names bound to the arrays. + */ + + try { + in.setStreamContext(streamContext); + in.open(); + Map> arrays = new HashMap(); + + //Initialize the variables + for(String var : vars) { + if(size > -1) { + arrays.put(var, new ArrayList(size)); + } else { + arrays.put(var, new ArrayList()); + } + } + + int count = 0; + + while (true) { + Tuple t = in.read(); + if (t.EOF) { + break; + } + + if(size == -1 || count < size) { + for (int i = 0; i < cols.length; i++) { + String col = cols[i]; + String var = vars[i]; + List array = arrays.get(var); + Number number = (Number) t.get(col); + array.add(number); + } + } + ++count; + } + + return arrays; + } catch (UncheckedIOException e) { + throw e.getCause(); + } finally { + in.close(); + } + } + + @Override + public Object doWork(Object... values) throws IOException { + // Nothing to do here + throw new IOException("This call should never occur"); + } +} + diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java index 8bb12a530b2..e88eaf6e2b7 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/LetStream.java @@ -27,6 +27,7 @@ import java.util.HashSet; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.StreamComparator; +import org.apache.solr.client.solrj.io.eval.MemsetEvaluator; import org.apache.solr.client.solrj.io.eval.StreamEvaluator; import org.apache.solr.client.solrj.io.stream.expr.Explanation; import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType; @@ -183,12 +184,18 @@ public class LetStream extends TupleStream implements Expressible { } } else { //Add the data from the StreamContext to a tuple. - //Let the evaluator work from this tuple. + //Let the evaluator works from this tuple. //This will allow columns to be created from tuples already in the StreamContext. Tuple eTuple = new Tuple(lets); StreamEvaluator evaluator = (StreamEvaluator)o; + evaluator.setStreamContext(streamContext); Object eo = evaluator.evaluate(eTuple); - lets.put(name, eo); + if(evaluator instanceof MemsetEvaluator) { + Map mem = (Map)eo; + lets.putAll(mem); + } else { + lets.put(name, eo); + } } } stream.open(); diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java index 87f5c4611e2..a98db517af0 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java @@ -68,7 +68,7 @@ public class TestLang extends LuceneTestCase { TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME, TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow", "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt", - "cbrt", "coalesce", "uuid", "if", "convert", "valueAt"}; + "cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset"}; @Test public void testLang() { diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java index 07570a9dbf5..0cf48844a0f 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.solr.client.solrj.io.stream; import java.io.IOException; @@ -205,6 +206,111 @@ public class MathExpressionTest extends SolrCloudTestCase { } } + @Test + public void testMemset() throws Exception { + String expr = "let(echo=\"b, c\"," + + " a=memset(list(tuple(field1=val(1), field2=val(10)), tuple(field1=val(2), field2=val(20))), " + + " cols=\"field1, field2\", " + + " vars=\"f1, f2\")," + + " b=add(f1)," + + " c=add(f2))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", expr); + paramsLoc.set("qt", "/stream"); + + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertEquals(tuples.size(), 1); + Number f1 = (Number)tuples.get(0).get("b"); + assertEquals(f1.doubleValue(), 3, 0.0); + + Number f2 = (Number)tuples.get(0).get("c"); + assertEquals(f2.doubleValue(), 30, 0.0); + } + + @Test + public void testMemsetSize() throws Exception { + String expr = "let(echo=\"b, c\"," + + " a=memset(list(tuple(field1=val(1), field2=val(10)), tuple(field1=val(2), field2=val(20))), " + + " cols=\"field1, field2\", " + + " vars=\"f1, f2\"," + + " size=1)," + + " b=add(f1)," + + " c=add(f2))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", expr); + paramsLoc.set("qt", "/stream"); + + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertEquals(tuples.size(), 1); + Number f1 = (Number)tuples.get(0).get("b"); + assertEquals(f1.doubleValue(), 1, 0.0); + + Number f2 = (Number)tuples.get(0).get("c"); + assertEquals(f2.doubleValue(), 10, 0.0); + } + + @Test + public void testMemsetTimeSeries() throws Exception { + UpdateRequest updateRequest = new UpdateRequest(); + + int i=0; + while(i<50) { + updateRequest.add(id, "id_"+(++i),"test_dt", getDateString("2016", "5", "1"), "price_f", "400.00"); + } + + while(i<100) { + updateRequest.add(id, "id_"+(++i),"test_dt", getDateString("2015", "5", "1"), "price_f", "300.0"); + } + + while(i<150) { + updateRequest.add(id, "id_"+(++i),"test_dt", getDateString("2014", "5", "1"), "price_f", "500.0"); + } + + while(i<250) { + updateRequest.add(id, "id_"+(++i),"test_dt", getDateString("2013", "5", "1"), "price_f", "100.00"); + } + + updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS); + + String expr = "memset(timeseries("+COLLECTIONORALIAS+", " + + " q=\"*:*\", " + + " start=\"2013-01-01T01:00:00.000Z\", " + + " end=\"2016-12-01T01:00:00.000Z\", " + + " gap=\"+1YEAR\", " + + " field=\"test_dt\", " + + " count(*)), " + + " cols=\"count(*)\"," + + " vars=\"a\")"; + + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", expr); + paramsLoc.set("qt", "/stream"); + + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Map> mem = (Map)tuples.get(0).get("return-value"); + List array = mem.get("a"); + assertEquals(array.get(0).intValue(), 100); + assertEquals(array.get(1).intValue(), 50); + assertEquals(array.get(2).intValue(), 50); + assertEquals(array.get(3).intValue(), 50); + } + @Test public void testHist() throws Exception { String expr = "hist(sequence(100, 0, 1), 10)";