mirror of https://github.com/apache/lucene.git
SOLR-13632: Support integral plots, cosine distance and string truncation with math expressions
This commit is contained in:
parent
4f89987141
commit
d4f612368d
|
@ -207,7 +207,7 @@ public class Lang {
|
|||
.withFunctionName("ttest", TTestEvaluator.class)
|
||||
.withFunctionName("pairedTtest", PairedTTestEvaluator.class)
|
||||
.withFunctionName("multiVariateNormalDistribution", MultiVariateNormalDistributionEvaluator.class)
|
||||
.withFunctionName("integrate", IntegrateEvaluator.class)
|
||||
.withFunctionName("integral", IntegrateEvaluator.class)
|
||||
.withFunctionName("density", DensityEvaluator.class)
|
||||
.withFunctionName("mannWhitney", MannWhitneyUEvaluator.class)
|
||||
.withFunctionName("sumSq", SumSqEvaluator.class)
|
||||
|
@ -300,6 +300,8 @@ public class Lang {
|
|||
.withFunctionName("upper", UpperEvaluator.class)
|
||||
.withFunctionName("split", SplitEvaluator.class)
|
||||
.withFunctionName("trim", TrimEvaluator.class)
|
||||
.withFunctionName("cosine", CosineDistanceEvaluator.class)
|
||||
.withFunctionName("trunc", TruncEvaluator.class)
|
||||
|
||||
// Boolean Stream Evaluators
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.exception.DimensionMismatchException;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class CosineDistanceEvaluator extends RecursiveEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public CosineDistanceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
public CosineDistanceEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{
|
||||
super(expression, factory, ignoredNamedParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
return new CosineDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
// Nothing to do here
|
||||
throw new IOException("This call should never occur");
|
||||
}
|
||||
|
||||
public static class CosineDistance implements DistanceMeasure {
|
||||
|
||||
private static final long serialVersionUID = -9108154600539125566L;
|
||||
|
||||
public double compute(double[] v1, double[] v2) throws DimensionMismatchException {
|
||||
return Precision.round(1-Math.abs(CosineSimilarityEvaluator.cosineSimilarity(v1, v2)), 8);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.util.Precision;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
|
@ -51,7 +52,7 @@ public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator impleme
|
|||
return cosineSimilarity(d1, d2);
|
||||
}
|
||||
|
||||
private double cosineSimilarity(double[] vectorA, double[] vectorB) {
|
||||
public static double cosineSimilarity(double[] vectorA, double[] vectorB) {
|
||||
double dotProduct = 0.0;
|
||||
double normA = 0.0;
|
||||
double normB = 0.0;
|
||||
|
@ -60,7 +61,8 @@ public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator impleme
|
|||
normA += Math.pow(vectorA[i], 2);
|
||||
normB += Math.pow(vectorB[i], 2);
|
||||
}
|
||||
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||
double d = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||
return Precision.round(d, 8);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.util.Locale;
|
|||
|
||||
import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.interpolation.AkimaSplineInterpolator;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
|
@ -42,12 +43,17 @@ public class DerivativeEvaluator extends RecursiveObjectEvaluator implements One
|
|||
}
|
||||
|
||||
VectorFunction vectorFunction = (VectorFunction) value;
|
||||
|
||||
DifferentiableUnivariateFunction func = null;
|
||||
double[] x = (double[])vectorFunction.getFromContext("x");
|
||||
|
||||
if(!(vectorFunction.getFunction() instanceof DifferentiableUnivariateFunction)) {
|
||||
throw new IOException("Cannot evaluate derivative from parameter.");
|
||||
double[] y = (double[])vectorFunction.getFromContext("y");
|
||||
func = new AkimaSplineInterpolator().interpolate(x, y);
|
||||
} else {
|
||||
func = (DifferentiableUnivariateFunction) vectorFunction.getFunction();
|
||||
}
|
||||
|
||||
DifferentiableUnivariateFunction func = (DifferentiableUnivariateFunction)vectorFunction.getFunction();
|
||||
double[] x = (double[])vectorFunction.getFromContext("x");
|
||||
UnivariateFunction derfunc = func.derivative();
|
||||
double[] dvalues = new double[x.length];
|
||||
for(int i=0; i<x.length; i++) {
|
||||
|
@ -56,7 +62,7 @@ public class DerivativeEvaluator extends RecursiveObjectEvaluator implements One
|
|||
|
||||
VectorFunction vf = new VectorFunction(derfunc, dvalues);
|
||||
vf.addToContext("x", x);
|
||||
vf.addToContext("y", vectorFunction.getFromContext("y"));
|
||||
vf.addToContext("y", dvalues);
|
||||
|
||||
return vf;
|
||||
}
|
||||
|
|
|
@ -17,10 +17,14 @@
|
|||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.analysis.UnivariateFunction;
|
||||
import org.apache.commons.math3.analysis.integration.MidPointIntegrator;
|
||||
import org.apache.commons.math3.analysis.integration.RombergIntegrator;
|
||||
import org.apache.commons.math3.analysis.integration.SimpsonIntegrator;
|
||||
import org.apache.commons.math3.analysis.integration.TrapezoidIntegrator;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
|
@ -34,8 +38,8 @@ public class IntegrateEvaluator extends RecursiveObjectEvaluator implements Many
|
|||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
|
||||
if(values.length != 3) {
|
||||
throw new IOException("The integrate function requires 3 parameters");
|
||||
if(values.length > 3) {
|
||||
throw new IOException("The integrate function requires at most 3 parameters");
|
||||
}
|
||||
|
||||
if (!(values[0] instanceof VectorFunction)) {
|
||||
|
@ -43,28 +47,45 @@ public class IntegrateEvaluator extends RecursiveObjectEvaluator implements Many
|
|||
}
|
||||
|
||||
VectorFunction vectorFunction = (VectorFunction) values[0];
|
||||
if(!(vectorFunction.getFunction() instanceof UnivariateFunction)) {
|
||||
if (!(vectorFunction.getFunction() instanceof UnivariateFunction)) {
|
||||
throw new IOException("Cannot evaluate integral from parameter.");
|
||||
}
|
||||
|
||||
UnivariateFunction func = (UnivariateFunction) vectorFunction.getFunction();
|
||||
|
||||
if(values.length == 3) {
|
||||
|
||||
|
||||
Number min = null;
|
||||
Number max = null;
|
||||
|
||||
if(values[1] instanceof Number) {
|
||||
if (values[1] instanceof Number) {
|
||||
min = (Number) values[1];
|
||||
} else {
|
||||
throw new IOException("The second parameter of the integrate function must be a number");
|
||||
}
|
||||
|
||||
if(values[2] instanceof Number ) {
|
||||
if (values[2] instanceof Number) {
|
||||
max = (Number) values[2];
|
||||
} else {
|
||||
throw new IOException("The third parameter of the integrate function must be a number");
|
||||
}
|
||||
|
||||
UnivariateFunction func = (UnivariateFunction)vectorFunction.getFunction();
|
||||
|
||||
RombergIntegrator rombergIntegrator = new RombergIntegrator();
|
||||
return rombergIntegrator.integrate(5000, func, min.doubleValue(), max.doubleValue());
|
||||
} else {
|
||||
RombergIntegrator integrator = new RombergIntegrator();
|
||||
|
||||
double[] x = (double[])vectorFunction.getFromContext("x");
|
||||
double[] y = (double[])vectorFunction.getFromContext("y");
|
||||
ArrayList<Number> out = new ArrayList();
|
||||
out.add(0);
|
||||
for(int i=1; i<x.length; i++) {
|
||||
out.add(integrator.integrate(5000, func, x[0], x[i]));
|
||||
}
|
||||
|
||||
return out;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,11 +73,13 @@ public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements Tw
|
|||
private List<Integer> getMaxIndexes(double[] values, int k) {
|
||||
TreeSet<Pair> set = new TreeSet();
|
||||
for(int i=0; i<values.length; i++) {
|
||||
if(values[i] > 0){
|
||||
set.add(new Pair(i, values[i]));
|
||||
if(set.size() > k) {
|
||||
if (set.size() > k) {
|
||||
set.pollFirst();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
List<Integer> top = new ArrayList(k);
|
||||
while(set.size() > 0) {
|
||||
|
@ -89,16 +91,22 @@ public class TopFeaturesEvaluator extends RecursiveObjectEvaluator implements Tw
|
|||
|
||||
public static class Pair implements Comparable<Pair> {
|
||||
|
||||
private int index;
|
||||
private Integer index;
|
||||
private Double value;
|
||||
|
||||
public Pair(int index, Number value) {
|
||||
this.index = index;
|
||||
public Pair(int _index, Number value) {
|
||||
this.index = _index;
|
||||
this.value = value.doubleValue();
|
||||
}
|
||||
|
||||
public int compareTo(Pair pair) {
|
||||
return value.compareTo(pair.value);
|
||||
|
||||
int c = value.compareTo(pair.value);
|
||||
if(c==0) {
|
||||
return index.compareTo(pair.index);
|
||||
} else {
|
||||
return c;
|
||||
}
|
||||
}
|
||||
|
||||
public int getIndex() {
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class TruncEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public TruncEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
if(2 != containedEvaluators.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 2 values but found %d",expression,containedEvaluators.size()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object value1, Object value2){
|
||||
if(null == value1){
|
||||
return null;
|
||||
}
|
||||
|
||||
int endIndex = ((Number)value2).intValue();
|
||||
|
||||
if(value1 instanceof List){
|
||||
return ((List<?>)value1).stream().map(innerValue -> doWork(innerValue, endIndex)).collect(Collectors.toList());
|
||||
}
|
||||
else {
|
||||
return value1.toString().substring(0, endIndex);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -57,7 +57,7 @@ public class TestLang extends SolrTestCase {
|
|||
"triangularDistribution", "precision", "minMaxScale", "markovChain", "grandSum",
|
||||
"scalarAdd", "scalarSubtract", "scalarMultiply", "scalarDivide", "sumRows",
|
||||
"sumColumns", "diff", "corrPValues", "normalizeSum", "geometricDistribution", "olsRegress",
|
||||
"derivative", "spline", "ttest", "pairedTtest", "multiVariateNormalDistribution", "integrate",
|
||||
"derivative", "spline", "ttest", "pairedTtest", "multiVariateNormalDistribution", "integral",
|
||||
"density", "mannWhitney", "sumSq", "akima", "lerp", "chiSquareDataSet", "gtestDataSet",
|
||||
"termVectors", "getColumnLabels", "getRowLabels", "getAttribute", "kmeans", "getCentroids",
|
||||
"getCluster", "topFeatures", "featureSelect", "rowAt", "colAt", "setColumnLabels",
|
||||
|
@ -77,7 +77,7 @@ public class TestLang extends SolrTestCase {
|
|||
"getSupportPoints", "pairSort", "log10", "plist", "recip", "pivot", "ltrim", "rtrim", "export",
|
||||
"zplot", "natural", "repeat", "movingMAD", "hashRollup", "noop", "var", "stddev", "recNum", "isNull",
|
||||
"notNull", "matches", "projectToBorder", "double", "long", "parseCSV", "parseTSV", "dateTime",
|
||||
"split", "upper", "trim", "lower"};
|
||||
"split", "upper", "trim", "lower", "trunc", "cosine"};
|
||||
|
||||
@Test
|
||||
public void testLang() {
|
||||
|
|
|
@ -229,6 +229,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
assertEquals(s2, "c-d-hello");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testTrunc() throws Exception {
|
||||
String expr = " select(list(tuple(field1=\"abcde\", field2=\"012345\")), trunc(field1, 2) as field3, trunc(field2, 4) as field4)";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", expr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertEquals(tuples.size(), 1);
|
||||
String s1 = tuples.get(0).getString("field3");
|
||||
assertEquals(s1, "ab");
|
||||
String s2 = tuples.get(0).getString("field4");
|
||||
assertEquals(s2, "0123");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUpperLowerSingle() throws Exception {
|
||||
String expr = " select(list(tuple(field1=\"a\", field2=\"C\")), upper(field1) as field3, lower(field2) as field4)";
|
||||
|
@ -249,6 +270,28 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
assertEquals(s2, "c");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testTruncArray() throws Exception {
|
||||
String expr = " select(list(tuple(field1=array(\"aaaa\",\"bbbb\",\"cccc\"))), trunc(field1, 3) as field2)";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", expr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertEquals(tuples.size(), 1);
|
||||
List<String> l1 = (List<String>)tuples.get(0).get("field2");
|
||||
assertEquals(l1.get(0), "aaa");
|
||||
assertEquals(l1.get(1), "bbb");
|
||||
assertEquals(l1.get(2), "ccc");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUpperLowerArray() throws Exception {
|
||||
String expr = " select(list(tuple(field1=array(\"a\",\"b\",\"c\"), field2=array(\"X\",\"Y\",\"Z\"))), upper(field1) as field3, lower(field2) as field4)";
|
||||
|
@ -722,6 +765,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
assertTrue(tuples.get(0).getDouble("cov").equals(-625.0D));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineDistance() throws Exception {
|
||||
String cexpr = "let(echo=true, " +
|
||||
"a=array(1,2,3,4)," +
|
||||
"b=array(10, 20, 30, 45), " +
|
||||
"c=distance(a, b, cosine()), " +
|
||||
")";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Number d = (Number) tuples.get(0).get("c");
|
||||
assertEquals(d.doubleValue(), 0.0017046159, 0.0001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistance() throws Exception {
|
||||
String cexpr = "let(echo=true, " +
|
||||
|
@ -3343,7 +3407,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Number cs = (Number)tuples.get(0).get("return-value");
|
||||
assertTrue(cs.doubleValue() == 0.9838197164968291);
|
||||
assertEquals(cs.doubleValue(),0.9838197164968291, .00000001);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -4085,9 +4149,10 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
String cexpr = "let(echo=true, " +
|
||||
"a=sequence(50, 1, 0), " +
|
||||
"b=spline(a), " +
|
||||
"c=integrate(b, 0, 49), " +
|
||||
"d=integrate(b, 0, 20), " +
|
||||
"e=integrate(b, 20, 49))";
|
||||
"c=integral(b, 0, 49), " +
|
||||
"d=integral(b, 0, 20), " +
|
||||
"e=integral(b, 20, 49)," +
|
||||
"f=integral(b))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
@ -4103,6 +4168,9 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
assertEquals(integral.doubleValue(), 20, 0.0);
|
||||
integral = (Number)tuples.get(0).get("e");
|
||||
assertEquals(integral.doubleValue(), 29, 0.0);
|
||||
List<Number> integrals = (List<Number>)tuples.get(0).get("f");
|
||||
assertEquals(integrals.size(), 50);
|
||||
assertEquals(integrals.get(49).intValue(), 49);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -4313,6 +4381,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testLerp() throws Exception {
|
||||
String cexpr = "let(echo=true," +
|
||||
|
|
Loading…
Reference in New Issue