From d4f612368dfdb4599bf42e42bbaf4ab481490266 Mon Sep 17 00:00:00 2001 From: Joel Bernstein <jbernste@apache.org> Date: Sun, 29 Sep 2019 19:00:30 -0400 Subject: [PATCH] SOLR-13632: Support integral plots, cosine distance and string truncation with math expressions --- .../org/apache/solr/client/solrj/io/Lang.java | 4 +- .../io/eval/CosineDistanceEvaluator.java | 60 ++++++++++++++ .../io/eval/CosineSimilarityEvaluator.java | 6 +- .../solrj/io/eval/DerivativeEvaluator.java | 14 +++- .../solrj/io/eval/IntegrateEvaluator.java | 59 +++++++++----- .../solrj/io/eval/TopFeaturesEvaluator.java | 22 ++++-- .../client/solrj/io/eval/TruncEvaluator.java | 53 +++++++++++++ .../apache/solr/client/solrj/io/TestLang.java | 4 +- .../solrj/io/stream/MathExpressionTest.java | 79 +++++++++++++++++-- 9 files changed, 261 insertions(+), 40 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineDistanceEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TruncEvaluator.java diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java index eed6b872a00..bd3710fbfe2 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java @@ -207,7 +207,7 @@ public class Lang { .withFunctionName("ttest", TTestEvaluator.class) .withFunctionName("pairedTtest", PairedTTestEvaluator.class) .withFunctionName("multiVariateNormalDistribution", MultiVariateNormalDistributionEvaluator.class) - .withFunctionName("integrate", IntegrateEvaluator.class) + .withFunctionName("integral", IntegrateEvaluator.class) .withFunctionName("density", DensityEvaluator.class) .withFunctionName("mannWhitney", MannWhitneyUEvaluator.class) .withFunctionName("sumSq", SumSqEvaluator.class) @@ -300,6 +300,8 @@ public class Lang { .withFunctionName("upper", UpperEvaluator.class) 
.withFunctionName("split", SplitEvaluator.class) .withFunctionName("trim", TrimEvaluator.class) + .withFunctionName("cosine", CosineDistanceEvaluator.class) + .withFunctionName("trunc", TruncEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineDistanceEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineDistanceEvaluator.java new file mode 100644 index 00000000000..564c7348a0e --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineDistanceEvaluator.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.List; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.util.Precision; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +/** + * Evaluator registered as cosine(). Evaluating it yields a CosineDistance measure rather than a number; + * the measure is meant to be handed to another function, e.g. distance(a, b, cosine()). + */ +public class CosineDistanceEvaluator extends RecursiveEvaluator { + protected static final long serialVersionUID = 1L; + + public CosineDistanceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + } + + public CosineDistanceEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{ + super(expression, factory, ignoredNamedParameters); + } + + /** Returns a new CosineDistance; the tuple is not consulted. */ + @Override + public Object evaluate(Tuple tuple) throws IOException { + return new CosineDistance(); + } + + @Override + public Object doWork(Object... 
values) throws IOException { + // Nothing to do here + throw new IOException("This call should never occur"); + } + + /** DistanceMeasure returning 1 - |cosine similarity|, rounded to 8 decimal places. */ + public static class CosineDistance implements DistanceMeasure { + + private static final long serialVersionUID = -9108154600539125566L; + + /** Computes the cosine distance of v1 and v2; throws DimensionMismatchException when their lengths differ. */ + @Override + public double compute(double[] v1, double[] v2) throws DimensionMismatchException { + if (v1.length != v2.length) { + throw new DimensionMismatchException(v2.length, v1.length); + } + return Precision.round(1-Math.abs(CosineSimilarityEvaluator.cosineSimilarity(v1, v2)), 8); + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java index 2b21ac8ff9c..07823c05543 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; import java.util.Locale; +import org.apache.commons.math3.util.Precision; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -51,7 +52,7 @@ public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator impleme return cosineSimilarity(d1, d2); } - private double cosineSimilarity(double[] vectorA, double[] vectorB) { + public static double cosineSimilarity(double[] vectorA, double[] vectorB) { double dotProduct = 0.0; double normA = 0.0; double normB = 0.0; @@ -60,7 +61,8 @@ public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator impleme normA += Math.pow(vectorA[i], 2); normB += Math.pow(vectorB[i], 2); } - return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + double d = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + return Precision.round(d, 8); } } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DerivativeEvaluator.java 
b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DerivativeEvaluator.java index 183a47babf0..895d3b5544f 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DerivativeEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DerivativeEvaluator.java @@ -21,6 +21,7 @@ import java.util.Locale; import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction; import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.interpolation.AkimaSplineInterpolator; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -42,12 +43,17 @@ public class DerivativeEvaluator extends RecursiveObjectEvaluator implements One } VectorFunction vectorFunction = (VectorFunction) value; + + DifferentiableUnivariateFunction func = null; + double[] x = (double[])vectorFunction.getFromContext("x"); + if(!(vectorFunction.getFunction() instanceof DifferentiableUnivariateFunction)) { - throw new IOException("Cannot evaluate derivative from parameter."); + double[] y = (double[])vectorFunction.getFromContext("y"); + func = new AkimaSplineInterpolator().interpolate(x, y); + } else { + func = (DifferentiableUnivariateFunction) vectorFunction.getFunction(); } - DifferentiableUnivariateFunction func = (DifferentiableUnivariateFunction)vectorFunction.getFunction(); - double[] x = (double[])vectorFunction.getFromContext("x"); UnivariateFunction derfunc = func.derivative(); double[] dvalues = new double[x.length]; for(int i=0; i<x.length; i++) { @@ -56,7 +62,7 @@ public class DerivativeEvaluator extends RecursiveObjectEvaluator implements One VectorFunction vf = new VectorFunction(derfunc, dvalues); vf.addToContext("x", x); - vf.addToContext("y", vectorFunction.getFromContext("y")); + vf.addToContext("y", dvalues); return vf; } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/IntegrateEvaluator.java 
b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/IntegrateEvaluator.java index 277748c8dd4..a09a131e1f9 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/IntegrateEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/IntegrateEvaluator.java @@ -17,10 +17,14 @@ package org.apache.solr.client.solrj.io.eval; import java.io.IOException; +import java.util.ArrayList; import java.util.Locale; import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.integration.MidPointIntegrator; import org.apache.commons.math3.analysis.integration.RombergIntegrator; +import org.apache.commons.math3.analysis.integration.SimpsonIntegrator; +import org.apache.commons.math3.analysis.integration.TrapezoidIntegrator; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -34,8 +38,8 @@ public class IntegrateEvaluator extends RecursiveObjectEvaluator implements Many @Override public Object doWork(Object... 
values) throws IOException { - if(values.length != 3) { - throw new IOException("The integrate function requires 3 parameters"); + if(values.length != 1 && values.length != 3) { /* One argument: list of running integrals over the function's x values. Three arguments: definite integral between the second and third. Any other count is rejected here, before values[0] is read. */ + throw new IOException("The integrate function requires either 1 or 3 parameters"); } if (!(values[0] instanceof VectorFunction)) { @@ -43,28 +47,45 @@ public class IntegrateEvaluator extends RecursiveObjectEvaluator implements Many } VectorFunction vectorFunction = (VectorFunction) values[0]; - if(!(vectorFunction.getFunction() instanceof UnivariateFunction)) { + if (!(vectorFunction.getFunction() instanceof UnivariateFunction)) { throw new IOException("Cannot evaluate integral from parameter."); } - Number min = null; - Number max = null; + UnivariateFunction func = (UnivariateFunction) vectorFunction.getFunction(); - if(values[1] instanceof Number) { - min = (Number) values[1]; + if(values.length == 3) { /* Definite integral between the two numeric bounds. */ + + + Number min = null; + Number max = null; + + if (values[1] instanceof Number) { + min = (Number) values[1]; + } else { + throw new IOException("The second parameter of the integrate function must be a number"); + } + + if (values[2] instanceof Number) { + max = (Number) values[2]; + } else { + throw new IOException("The third parameter of the integrate function must be a number"); + } + + RombergIntegrator rombergIntegrator = new RombergIntegrator(); + return rombergIntegrator.integrate(5000, func, min.doubleValue(), max.doubleValue()); } else { - throw new IOException("The second parameter of the integrate function must be a number"); + RombergIntegrator integrator = new RombergIntegrator(); + + double[] x = (double[])vectorFunction.getFromContext("x"); + /* Running integral from x[0] to each x[i]; the entry for x[0] itself is 0. */ + ArrayList<Number> out = new ArrayList<>(); + out.add(0); + for(int i=1; i<x.length; i++) { + out.add(integrator.integrate(5000, func, x[0], x[i])); + } + + return out; + } - - if(values[2] instanceof Number ) { - max = (Number) values[2]; - } else { - throw new IOException("The third parameter of 
the integrate function must be a number"); - } - - UnivariateFunction func = (UnivariateFunction)vectorFunction.getFunction(); - - RombergIntegrator rombergIntegrator = new RombergIntegrator(); - return rombergIntegrator.integrate(5000, func, min.doubleValue(), max.doubleValue()); } } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java index e2100b1fbde..e2dddfb8ff7 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TopFeaturesEvaluator.java @@ -73,9 +73,11 @@ public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements Tw private List<Integer> getMaxIndexes(double[] values, int k) { TreeSet<Pair> set = new TreeSet(); for(int i=0; i<values.length; i++) { - set.add(new Pair(i, values[i])); - if(set.size() > k) { - set.pollFirst(); + if(values[i] > 0){ + set.add(new Pair(i, values[i])); + if (set.size() > k) { + set.pollFirst(); + } } } @@ -89,16 +91,22 @@ public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements Tw public static class Pair implements Comparable<Pair> { - private int index; + private Integer index; private Double value; - public Pair(int index, Number value) { - this.index = index; + public Pair(int _index, Number value) { + this.index = _index; this.value = value.doubleValue(); } public int compareTo(Pair pair) { - return value.compareTo(pair.value); + + int c = value.compareTo(pair.value); + if(c==0) { + return index.compareTo(pair.index); + } else { + return c; + } } public int getIndex() { diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TruncEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TruncEvaluator.java new file mode 100644 index 00000000000..0e4ebaca488 --- /dev/null +++ 
b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TruncEvaluator.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +/** + * Two-argument evaluator that returns the leading characters of a value's string form, + * or of each element when the value is a List. The second argument is the number of characters to keep. + */ +public class TruncEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker { + protected static final long serialVersionUID = 1L; + + public TruncEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + if(2 != containedEvaluators.size()){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size())); + } + } + + /** + * Truncates value1 to at most value2 characters. A List is truncated element by element and + * null is returned unchanged; a value shorter than the requested length is returned whole. + */ + @Override + public Object doWork(Object value1, Object value2){ + if(null == value1){ + return null; + } + + int endIndex = ((Number)value2).intValue(); + + if(value1 instanceof List){ + return ((List<?>)value1).stream().map(innerValue -> doWork(innerValue, endIndex)).collect(Collectors.toList()); + } 
+ else { + String str = value1.toString(); + /* Clamp so a length past the end of the string returns the whole string instead of throwing. */ + return str.substring(0, Math.min(endIndex, str.length())); + } + } +} diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java index 0435ed55f76..2e427002c9d 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java @@ -57,7 +57,7 @@ public class TestLang extends SolrTestCase { "triangularDistribution", "precision", "minMaxScale", "markovChain", "grandSum", "scalarAdd", "scalarSubtract", "scalarMultiply", "scalarDivide", "sumRows", "sumColumns", "diff", "corrPValues", "normalizeSum", "geometricDistribution", "olsRegress", - "derivative", "spline", "ttest", "pairedTtest", "multiVariateNormalDistribution", "integrate", + "derivative", "spline", "ttest", "pairedTtest", "multiVariateNormalDistribution", "integral", "density", "mannWhitney", "sumSq", "akima", "lerp", "chiSquareDataSet", "gtestDataSet", "termVectors", "getColumnLabels", "getRowLabels", "getAttribute", "kmeans", "getCentroids", "getCluster", "topFeatures", "featureSelect", "rowAt", "colAt", "setColumnLabels", @@ -77,7 +77,7 @@ public class TestLang extends SolrTestCase { "getSupportPoints", "pairSort", "log10", "plist", "recip", "pivot", "ltrim", "rtrim", "export", "zplot", "natural", "repeat", "movingMAD", "hashRollup", "noop", "var", "stddev", "recNum", "isNull", "notNull", "matches", "projectToBorder", "double", "long", "parseCSV", "parseTSV", "dateTime", - "split", "upper", "trim", "lower"}; + "split", "upper", "trim", "lower", "trunc", "cosine"}; @Test public void testLang() { diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java index 890d0d33b33..f69f369290d 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java +++ 
b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java @@ -229,6 +229,27 @@ public class MathExpressionTest extends SolrCloudTestCase { assertEquals(s2, "c-d-hello"); } + + /** Sends a select expression applying trunc to two string fields and expects the first 2 characters of abcde and the first 4 of 012345. */ @Test + public void testTrunc() throws Exception { + String expr = " select(list(tuple(field1=\"abcde\", field2=\"012345\")), trunc(field1, 2) as field3, trunc(field2, 4) as field4)"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", expr); + paramsLoc.set("qt", "/stream"); + + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List<Tuple> tuples = getTuples(solrStream); + assertEquals(tuples.size(), 1); + String s1 = tuples.get(0).getString("field3"); + assertEquals(s1, "ab"); + String s2 = tuples.get(0).getString("field4"); + assertEquals(s2, "0123"); + } + @Test public void testUpperLowerSingle() throws Exception { String expr = " select(list(tuple(field1=\"a\", field2=\"C\")), upper(field1) as field3, lower(field2) as field4)"; @@ -249,6 +270,28 @@ public class MathExpressionTest extends SolrCloudTestCase { assertEquals(s2, "c"); } + + /** Applies trunc with length 3 to an array field and expects every four-character element to come back three characters long. */ @Test + public void testTruncArray() throws Exception { + String expr = " select(list(tuple(field1=array(\"aaaa\",\"bbbb\",\"cccc\"))), trunc(field1, 3) as field2)"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", expr); + paramsLoc.set("qt", "/stream"); + + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List<Tuple> tuples = getTuples(solrStream); + assertEquals(tuples.size(), 1); + List<String> l1 = (List<String>)tuples.get(0).get("field2"); + 
assertEquals(l1.get(0), "aaa"); + assertEquals(l1.get(1), "bbb"); + assertEquals(l1.get(2), "ccc"); + + } + @Test public void testUpperLowerArray() throws Exception { String expr = " select(list(tuple(field1=array(\"a\",\"b\",\"c\"), field2=array(\"X\",\"Y\",\"Z\"))), upper(field1) as field3, lower(field2) as field4)"; @@ -722,6 +765,27 @@ public class MathExpressionTest extends SolrCloudTestCase { assertTrue(tuples.get(0).getDouble("cov").equals(-625.0D)); } + /** Evaluates distance(a, b, cosine()) for a=[1,2,3,4] and b=[10,20,30,45] and compares the result with 0.0017046159 using a tolerance of 0.0001. */ @Test + public void testCosineDistance() throws Exception { + String cexpr = "let(echo=true, " + + "a=array(1,2,3,4)," + + "b=array(10, 20, 30, 45), " + + "c=distance(a, b, cosine()), " + + ")"; + + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List<Tuple> tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + Number d = (Number) tuples.get(0).get("c"); + assertEquals(d.doubleValue(), 0.0017046159, 0.0001); + } + @Test public void testDistance() throws Exception { String cexpr = "let(echo=true, " + @@ -3343,7 +3407,7 @@ public class MathExpressionTest extends SolrCloudTestCase { List<Tuple> tuples = getTuples(solrStream); assertTrue(tuples.size() == 1); Number cs = (Number)tuples.get(0).get("return-value"); - assertTrue(cs.doubleValue() == 0.9838197164968291); + assertEquals(cs.doubleValue(),0.9838197164968291, .00000001); } @Test @@ -4085,9 +4149,10 @@ public class MathExpressionTest extends SolrCloudTestCase { String cexpr = "let(echo=true, " + "a=sequence(50, 1, 0), " + "b=spline(a), " + - "c=integrate(b, 0, 49), " + - "d=integrate(b, 0, 20), " + - "e=integrate(b, 20, 49))"; + "c=integral(b, 0, 49), " + + "d=integral(b, 0, 20), " + + "e=integral(b, 20, 49)," + + 
"f=integral(b))"; ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); paramsLoc.set("expr", cexpr); paramsLoc.set("qt", "/stream"); @@ -4103,6 +4168,9 @@ public class MathExpressionTest extends SolrCloudTestCase { assertEquals(integral.doubleValue(), 20, 0.0); integral = (Number)tuples.get(0).get("e"); assertEquals(integral.doubleValue(), 29, 0.0); + List<Number> integrals = (List<Number>)tuples.get(0).get("f"); + assertEquals(integrals.size(), 50); + assertEquals(integrals.get(49).intValue(), 49); } @Test @@ -4313,7 +4381,8 @@ public class MathExpressionTest extends SolrCloudTestCase { } - @Test + + @Test public void testLerp() throws Exception { String cexpr = "let(echo=true," + " a=array(0,1,2,3,4,5,6,7), " +