SOLR-13632: Support integral plots, cosine distance and string truncation with math expressions

This commit is contained in:
Joel Bernstein 2019-09-29 19:00:30 -04:00
parent 4f89987141
commit d4f612368d
9 changed files with 261 additions and 40 deletions

View File

@ -207,7 +207,7 @@ public class Lang {
.withFunctionName("ttest", TTestEvaluator.class)
.withFunctionName("pairedTtest", PairedTTestEvaluator.class)
.withFunctionName("multiVariateNormalDistribution", MultiVariateNormalDistributionEvaluator.class)
.withFunctionName("integrate", IntegrateEvaluator.class)
.withFunctionName("integral", IntegrateEvaluator.class)
.withFunctionName("density", DensityEvaluator.class)
.withFunctionName("mannWhitney", MannWhitneyUEvaluator.class)
.withFunctionName("sumSq", SumSqEvaluator.class)
@ -300,6 +300,8 @@ public class Lang {
.withFunctionName("upper", UpperEvaluator.class)
.withFunctionName("split", SplitEvaluator.class)
.withFunctionName("trim", TrimEvaluator.class)
.withFunctionName("cosine", CosineDistanceEvaluator.class)
.withFunctionName("trunc", TruncEvaluator.class)
// Boolean Stream Evaluators

View File

@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.util.Precision;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class CosineDistanceEvaluator extends RecursiveEvaluator {
protected static final long serialVersionUID = 1L;
public CosineDistanceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
}
public CosineDistanceEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{
super(expression, factory, ignoredNamedParameters);
}
@Override
public Object evaluate(Tuple tuple) throws IOException {
return new CosineDistance();
}
@Override
public Object doWork(Object... values) throws IOException {
// Nothing to do here
throw new IOException("This call should never occur");
}
public static class CosineDistance implements DistanceMeasure {
private static final long serialVersionUID = -9108154600539125566L;
public double compute(double[] v1, double[] v2) throws DimensionMismatchException {
return Precision.round(1-Math.abs(CosineSimilarityEvaluator.cosineSimilarity(v1, v2)), 8);
}
}
}

View File

@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.util.Precision;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -51,7 +52,7 @@ public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator impleme
return cosineSimilarity(d1, d2);
}
private double cosineSimilarity(double[] vectorA, double[] vectorB) {
public static double cosineSimilarity(double[] vectorA, double[] vectorB) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
@ -60,7 +61,8 @@ public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator impleme
normA += Math.pow(vectorA[i], 2);
normB += Math.pow(vectorB[i], 2);
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
double d = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
return Precision.round(d, 8);
}
}

View File

@ -21,6 +21,7 @@ import java.util.Locale;
import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.interpolation.AkimaSplineInterpolator;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -42,12 +43,17 @@ public class DerivativeEvaluator extends RecursiveObjectEvaluator implements One
}
VectorFunction vectorFunction = (VectorFunction) value;
DifferentiableUnivariateFunction func = null;
double[] x = (double[])vectorFunction.getFromContext("x");
if(!(vectorFunction.getFunction() instanceof DifferentiableUnivariateFunction)) {
throw new IOException("Cannot evaluate derivative from parameter.");
double[] y = (double[])vectorFunction.getFromContext("y");
func = new AkimaSplineInterpolator().interpolate(x, y);
} else {
func = (DifferentiableUnivariateFunction) vectorFunction.getFunction();
}
DifferentiableUnivariateFunction func = (DifferentiableUnivariateFunction)vectorFunction.getFunction();
double[] x = (double[])vectorFunction.getFromContext("x");
UnivariateFunction derfunc = func.derivative();
double[] dvalues = new double[x.length];
for(int i=0; i<x.length; i++) {
@ -56,7 +62,7 @@ public class DerivativeEvaluator extends RecursiveObjectEvaluator implements One
VectorFunction vf = new VectorFunction(derfunc, dvalues);
vf.addToContext("x", x);
vf.addToContext("y", vectorFunction.getFromContext("y"));
vf.addToContext("y", dvalues);
return vf;
}

View File

@ -17,10 +17,14 @@
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Locale;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.integration.MidPointIntegrator;
import org.apache.commons.math3.analysis.integration.RombergIntegrator;
import org.apache.commons.math3.analysis.integration.SimpsonIntegrator;
import org.apache.commons.math3.analysis.integration.TrapezoidIntegrator;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -34,8 +38,8 @@ public class IntegrateEvaluator extends RecursiveObjectEvaluator implements Many
@Override
public Object doWork(Object... values) throws IOException {
if(values.length != 3) {
throw new IOException("The integrate function requires 3 parameters");
if(values.length > 3) {
throw new IOException("The integrate function requires at most 3 parameters");
}
if (!(values[0] instanceof VectorFunction)) {
@ -43,28 +47,45 @@ public class IntegrateEvaluator extends RecursiveObjectEvaluator implements Many
}
VectorFunction vectorFunction = (VectorFunction) values[0];
if(!(vectorFunction.getFunction() instanceof UnivariateFunction)) {
if (!(vectorFunction.getFunction() instanceof UnivariateFunction)) {
throw new IOException("Cannot evaluate integral from parameter.");
}
UnivariateFunction func = (UnivariateFunction) vectorFunction.getFunction();
if(values.length == 3) {
Number min = null;
Number max = null;
if(values[1] instanceof Number) {
if (values[1] instanceof Number) {
min = (Number) values[1];
} else {
throw new IOException("The second parameter of the integrate function must be a number");
}
if(values[2] instanceof Number ) {
if (values[2] instanceof Number) {
max = (Number) values[2];
} else {
throw new IOException("The third parameter of the integrate function must be a number");
}
UnivariateFunction func = (UnivariateFunction)vectorFunction.getFunction();
RombergIntegrator rombergIntegrator = new RombergIntegrator();
return rombergIntegrator.integrate(5000, func, min.doubleValue(), max.doubleValue());
} else {
RombergIntegrator integrator = new RombergIntegrator();
double[] x = (double[])vectorFunction.getFromContext("x");
double[] y = (double[])vectorFunction.getFromContext("y");
ArrayList<Number> out = new ArrayList();
out.add(0);
for(int i=1; i<x.length; i++) {
out.add(integrator.integrate(5000, func, x[0], x[i]));
}
return out;
}
}
}

View File

@ -73,11 +73,13 @@ public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements Tw
private List<Integer> getMaxIndexes(double[] values, int k) {
TreeSet<Pair> set = new TreeSet();
for(int i=0; i<values.length; i++) {
if(values[i] > 0){
set.add(new Pair(i, values[i]));
if(set.size() > k) {
if (set.size() > k) {
set.pollFirst();
}
}
}
List<Integer> top = new ArrayList(k);
while(set.size() > 0) {
@ -89,16 +91,22 @@ public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements Tw
public static class Pair implements Comparable<Pair> {
private int index;
private Integer index;
private Double value;
public Pair(int index, Number value) {
this.index = index;
public Pair(int _index, Number value) {
this.index = _index;
this.value = value.doubleValue();
}
public int compareTo(Pair pair) {
return value.compareTo(pair.value);
int c = value.compareTo(pair.value);
if(c==0) {
return index.compareTo(pair.index);
} else {
return c;
}
}
public int getIndex() {

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class TruncEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public TruncEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(2 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value1, Object value2){
if(null == value1){
return null;
}
int endIndex = ((Number)value2).intValue();
if(value1 instanceof List){
return ((List<?>)value1).stream().map(innerValue -> doWork(innerValue, endIndex)).collect(Collectors.toList());
}
else {
return value1.toString().substring(0, endIndex);
}
}
}

View File

@ -57,7 +57,7 @@ public class TestLang extends SolrTestCase {
"triangularDistribution", "precision", "minMaxScale", "markovChain", "grandSum",
"scalarAdd", "scalarSubtract", "scalarMultiply", "scalarDivide", "sumRows",
"sumColumns", "diff", "corrPValues", "normalizeSum", "geometricDistribution", "olsRegress",
"derivative", "spline", "ttest", "pairedTtest", "multiVariateNormalDistribution", "integrate",
"derivative", "spline", "ttest", "pairedTtest", "multiVariateNormalDistribution", "integral",
"density", "mannWhitney", "sumSq", "akima", "lerp", "chiSquareDataSet", "gtestDataSet",
"termVectors", "getColumnLabels", "getRowLabels", "getAttribute", "kmeans", "getCentroids",
"getCluster", "topFeatures", "featureSelect", "rowAt", "colAt", "setColumnLabels",
@ -77,7 +77,7 @@ public class TestLang extends SolrTestCase {
"getSupportPoints", "pairSort", "log10", "plist", "recip", "pivot", "ltrim", "rtrim", "export",
"zplot", "natural", "repeat", "movingMAD", "hashRollup", "noop", "var", "stddev", "recNum", "isNull",
"notNull", "matches", "projectToBorder", "double", "long", "parseCSV", "parseTSV", "dateTime",
"split", "upper", "trim", "lower"};
"split", "upper", "trim", "lower", "trunc", "cosine"};
@Test
public void testLang() {

View File

@ -229,6 +229,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(s2, "c-d-hello");
}
@Test
public void testTrunc() throws Exception {
String expr = " select(list(tuple(field1=\"abcde\", field2=\"012345\")), trunc(field1, 2) as field3, trunc(field2, 4) as field4)";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", expr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertEquals(tuples.size(), 1);
String s1 = tuples.get(0).getString("field3");
assertEquals(s1, "ab");
String s2 = tuples.get(0).getString("field4");
assertEquals(s2, "0123");
}
@Test
public void testUpperLowerSingle() throws Exception {
String expr = " select(list(tuple(field1=\"a\", field2=\"C\")), upper(field1) as field3, lower(field2) as field4)";
@ -249,6 +270,28 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(s2, "c");
}
@Test
public void testTruncArray() throws Exception {
String expr = " select(list(tuple(field1=array(\"aaaa\",\"bbbb\",\"cccc\"))), trunc(field1, 3) as field2)";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", expr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertEquals(tuples.size(), 1);
List<String> l1 = (List<String>)tuples.get(0).get("field2");
assertEquals(l1.get(0), "aaa");
assertEquals(l1.get(1), "bbb");
assertEquals(l1.get(2), "ccc");
}
@Test
public void testUpperLowerArray() throws Exception {
String expr = " select(list(tuple(field1=array(\"a\",\"b\",\"c\"), field2=array(\"X\",\"Y\",\"Z\"))), upper(field1) as field3, lower(field2) as field4)";
@ -722,6 +765,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertTrue(tuples.get(0).getDouble("cov").equals(-625.0D));
}
@Test
public void testCosineDistance() throws Exception {
String cexpr = "let(echo=true, " +
"a=array(1,2,3,4)," +
"b=array(10, 20, 30, 45), " +
"c=distance(a, b, cosine()), " +
")";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Number d = (Number) tuples.get(0).get("c");
assertEquals(d.doubleValue(), 0.0017046159, 0.0001);
}
@Test
public void testDistance() throws Exception {
String cexpr = "let(echo=true, " +
@ -3343,7 +3407,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Number cs = (Number)tuples.get(0).get("return-value");
assertTrue(cs.doubleValue() == 0.9838197164968291);
assertEquals(cs.doubleValue(),0.9838197164968291, .00000001);
}
@Test
@ -4085,9 +4149,10 @@ public class MathExpressionTest extends SolrCloudTestCase {
String cexpr = "let(echo=true, " +
"a=sequence(50, 1, 0), " +
"b=spline(a), " +
"c=integrate(b, 0, 49), " +
"d=integrate(b, 0, 20), " +
"e=integrate(b, 20, 49))";
"c=integral(b, 0, 49), " +
"d=integral(b, 0, 20), " +
"e=integral(b, 20, 49)," +
"f=integral(b))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
@ -4103,6 +4168,9 @@ public class MathExpressionTest extends SolrCloudTestCase {
assertEquals(integral.doubleValue(), 20, 0.0);
integral = (Number)tuples.get(0).get("e");
assertEquals(integral.doubleValue(), 29, 0.0);
List<Number> integrals = (List<Number>)tuples.get(0).get("f");
assertEquals(integrals.size(), 50);
assertEquals(integrals.get(49).intValue(), 49);
}
@Test
@ -4313,6 +4381,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
}
@Test
public void testLerp() throws Exception {
String cexpr = "let(echo=true," +