SOLR-12273: Create Stream Evaluators for distance measures

This commit is contained in:
Joel Bernstein 2018-04-25 21:37:50 -04:00
parent aa341476fd
commit bea6f42105
10 changed files with 321 additions and 123 deletions

View File

@ -237,6 +237,11 @@ public class Lang {
.withFunctionName("memset", MemsetEvaluator.class)
.withFunctionName("fft", FFTEvaluator.class)
.withFunctionName("ifft", IFFTEvaluator.class)
.withFunctionName("manhattan", ManhattanEvaluator.class)
.withFunctionName("canberra", CanberraEvaluator.class)
.withFunctionName("earthMovers", EarthMoversEvaluator.class)
.withFunctionName("euclidean", EuclideanEvaluator.class)
.withFunctionName("chebyshev", ChebyshevEvaluator.class)
// Boolean Stream Evaluators

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a {@link CanberraDistance} object rather than a computed value.
 *
 * <p>{@link #evaluate(Tuple)} ignores the tuple it is given and hands back a newly constructed
 * measure on every call. {@link #doWork(Object...)} is not supported and always throws.
 */
public class CanberraEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public CanberraEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @param ignoredNamedParameters forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public CanberraEvaluator(
      StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters)
      throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Unsupported: this class overrides {@link #evaluate(Tuple)} to return the measure directly,
   * so there are no evaluated values to work on.
   *
   * @param values ignored
   * @return never returns normally
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    throw new IOException("This call should never occur");
  }

  /**
   * Produces the distance measure.
   *
   * @param tuple ignored
   * @return a new {@link CanberraDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final CanberraDistance measure = new CanberraDistance();
    return measure;
  }
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.ChebyshevDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a {@link ChebyshevDistance} object rather than a computed value.
 *
 * <p>{@link #evaluate(Tuple)} ignores the tuple it is given and hands back a newly constructed
 * measure on every call. {@link #doWork(Object...)} is not supported and always throws.
 */
public class ChebyshevEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public ChebyshevEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @param ignoredNamedParameters forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public ChebyshevEvaluator(
      StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters)
      throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Unsupported: this class overrides {@link #evaluate(Tuple)} to return the measure directly,
   * so there are no evaluated values to work on.
   *
   * @param values ignored
   * @return never returns normally
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    throw new IOException("This call should never occur");
  }

  /**
   * Produces the distance measure.
   *
   * @param tuple ignored
   * @return a new {@link ChebyshevDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final ChebyshevDistance measure = new ChebyshevDistance();
    return measure;
  }
}

View File

@ -22,14 +22,9 @@ import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
@ -40,94 +35,75 @@ public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyV
public DistanceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
if(namedParams.size() > 0) {
if (namedParams.size() > 1) {
throw new IOException("distance function expects only one named parameter 'type'.");
}
StreamExpressionNamedParameter namedParameter = namedParams.get(0);
String name = namedParameter.getName();
if (!name.equalsIgnoreCase("type")) {
throw new IOException("distance function expects only one named parameter 'type'.");
}
String typeParam = namedParameter.getParameter().toString().trim();
this.type= DistanceType.valueOf(typeParam);
} else {
this.type = DistanceType.euclidean;
}
}
@Override
public Object doWork(Object ... values) throws IOException{
if(values.length == 2) {
if(values.length == 1) {
if (values[0] instanceof Matrix) {
Matrix matrix = (Matrix) values[0];
EuclideanDistance euclideanDistance = new EuclideanDistance();
return distance(euclideanDistance, matrix);
} else {
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
}
} else if(values.length == 2) {
Object first = values[0];
Object second = values[1];
if (null == first) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory)));
}
if (null == second) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory)));
}
if(first instanceof Matrix) {
Matrix matrix = (Matrix) first;
DistanceMeasure distanceMeasure = (DistanceMeasure)second;
return distance(distanceMeasure, matrix);
} else {
if (!(first instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if (!(second instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if (type.equals(DistanceType.euclidean)) {
EuclideanDistance euclideanDistance = new EuclideanDistance();
return euclideanDistance.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
} else if (type.equals(DistanceType.manhattan)) {
ManhattanDistance manhattanDistance = new ManhattanDistance();
return manhattanDistance.compute(
DistanceMeasure distanceMeasure = new EuclideanDistance();
return distanceMeasure.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
}
} else if (values.length == 3) {
Object first = values[0];
Object second = values[1];
DistanceMeasure distanceMeasure = (DistanceMeasure)values[2];
} else if (type.equals(DistanceType.canberra)) {
CanberraDistance canberraDistance = new CanberraDistance();
return canberraDistance.compute(
if (null == first) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory)));
}
if (null == second) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory)));
}
if (!(first instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if (!(second instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
}
return distanceMeasure.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
} else if (type.equals(DistanceType.earthMovers)) {
EarthMoversDistance earthMoversDistance = new EarthMoversDistance();
return earthMoversDistance.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
} else {
return null;
}
} else if(values.length == 1) {
if(values[0] instanceof Matrix) {
Matrix matrix = (Matrix)values[0];
if (type.equals(DistanceType.euclidean)) {
EuclideanDistance euclideanDistance = new EuclideanDistance();
return distance(euclideanDistance, matrix);
} else if (type.equals(DistanceType.canberra)) {
CanberraDistance canberraDistance = new CanberraDistance();
return distance(canberraDistance, matrix);
} else if (type.equals(DistanceType.manhattan)) {
ManhattanDistance manhattanDistance = new ManhattanDistance();
return distance(manhattanDistance, matrix);
} else if (type.equals(DistanceType.earthMovers)) {
EarthMoversDistance earthMoversDistance = new EarthMoversDistance();
return distance(earthMoversDistance, matrix);
} else {
return null;
}
} else {
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
}
} else {
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is an {@link EarthMoversDistance} object rather than a computed value.
 *
 * <p>{@link #evaluate(Tuple)} ignores the tuple it is given and hands back a newly constructed
 * measure on every call. {@link #doWork(Object...)} is not supported and always throws.
 */
public class EarthMoversEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public EarthMoversEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @param ignoredNamedParameters forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public EarthMoversEvaluator(
      StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters)
      throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Unsupported: this class overrides {@link #evaluate(Tuple)} to return the measure directly,
   * so there are no evaluated values to work on.
   *
   * @param values ignored
   * @return never returns normally
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    throw new IOException("This call should never occur");
  }

  /**
   * Produces the distance measure.
   *
   * @param tuple ignored
   * @return a new {@link EarthMoversDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final EarthMoversDistance measure = new EarthMoversDistance();
    return measure;
  }
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a {@link EuclideanDistance} object rather than a computed value.
 *
 * <p>{@link #evaluate(Tuple)} ignores the tuple it is given and hands back a newly constructed
 * measure on every call. {@link #doWork(Object...)} is not supported and always throws.
 */
public class EuclideanEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public EuclideanEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @param ignoredNamedParameters forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public EuclideanEvaluator(
      StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters)
      throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Unsupported: this class overrides {@link #evaluate(Tuple)} to return the measure directly,
   * so there are no evaluated values to work on.
   *
   * @param values ignored
   * @return never returns normally
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    throw new IOException("This call should never occur");
  }

  /**
   * Produces the distance measure.
   *
   * @param tuple ignored
   * @return a new {@link EuclideanDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final EuclideanDistance measure = new EuclideanDistance();
    return measure;
  }
}

View File

@ -22,52 +22,16 @@ import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
private DistanceMeasure distanceMeasure;
public KnnEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
DistanceEvaluator.DistanceType type = null;
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
if(namedParams.size() > 0) {
if (namedParams.size() > 1) {
throw new IOException("distance function expects only one named parameter 'distance'.");
}
StreamExpressionNamedParameter namedParameter = namedParams.get(0);
String name = namedParameter.getName();
if (!name.equalsIgnoreCase("distance")) {
throw new IOException("distance function expects only one named parameter 'distance'.");
}
String typeParam = namedParameter.getParameter().toString().trim();
type= DistanceEvaluator.DistanceType.valueOf(typeParam);
} else {
type = DistanceEvaluator.DistanceType.euclidean;
}
if (type.equals(DistanceEvaluator.DistanceType.euclidean)) {
distanceMeasure = new EuclideanDistance();
} else if (type.equals(DistanceEvaluator.DistanceType.manhattan)) {
distanceMeasure = new ManhattanDistance();
} else if (type.equals(DistanceEvaluator.DistanceType.canberra)) {
distanceMeasure = new CanberraDistance();
} else if (type.equals(DistanceEvaluator.DistanceType.earthMovers)) {
distanceMeasure = new EarthMoversDistance();
}
}
@Override
@ -105,6 +69,14 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
double[][] data = matrix.getData();
DistanceMeasure distanceMeasure = null;
if(values.length == 4) {
distanceMeasure = (DistanceMeasure)values[3];
} else {
distanceMeasure = new EuclideanDistance();
}
TreeSet<Neighbor> neighbors = new TreeSet();
for(int i=0; i<data.length; i++) {
double distance = distanceMeasure.compute(vec, data[i]);
@ -165,6 +137,5 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
return this.distance.compareTo(neighbor.getDistance());
}
}
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a {@link ManhattanDistance} object rather than a computed value.
 *
 * <p>{@link #evaluate(Tuple)} ignores the tuple it is given and hands back a newly constructed
 * measure on every call. {@link #doWork(Object...)} is not supported and always throws.
 */
public class ManhattanEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public ManhattanEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator, delegating all construction work to the superclass.
   *
   * @param expression the expression, forwarded unchanged to the superclass
   * @param factory the factory, forwarded unchanged to the superclass
   * @param ignoredNamedParameters forwarded unchanged to the superclass
   * @throws IOException if the superclass constructor throws it
   */
  public ManhattanEvaluator(
      StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters)
      throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Unsupported: this class overrides {@link #evaluate(Tuple)} to return the measure directly,
   * so there are no evaluated values to work on.
   *
   * @param values ignored
   * @return never returns normally
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    throw new IOException("This call should never occur");
  }

  /**
   * Produces the distance measure.
   *
   * @param tuple ignored
   * @return a new {@link ManhattanDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final ManhattanDistance measure = new ManhattanDistance();
    return measure;
  }
}

View File

@ -68,7 +68,8 @@ public class TestLang extends LuceneTestCase {
TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME,
TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
"mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
"cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft"};
"cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan",
"earthMovers", "canberra", "chebyshev"};
@Test
public void testLang() {

View File

@ -481,21 +481,21 @@ public class MathExpressionTest extends SolrCloudTestCase {
"f=distance(b, c)," +
"g=transpose(matrix(a, b, c))," +
"h=distance(g)," +
"i=distance(a, b, type=manhattan), " +
"j=distance(a, c, type=manhattan)," +
"k=distance(b, c, type=manhattan)," +
"i=distance(a, b, manhattan()), " +
"j=distance(a, c, manhattan())," +
"k=distance(b, c, manhattan())," +
"l=transpose(matrix(a, b, c))," +
"m=distance(l, type=manhattan)," +
"n=distance(a, b, type=canberra), " +
"o=distance(a, c, type=canberra)," +
"p=distance(b, c, type=canberra)," +
"m=distance(l, manhattan())," +
"n=distance(a, b, canberra()), " +
"o=distance(a, c, canberra())," +
"p=distance(b, c, canberra())," +
"q=transpose(matrix(a, b, c))," +
"r=distance(q, type=canberra)," +
"s=distance(a, b, type=earthMovers), " +
"t=distance(a, c, type=earthMovers)," +
"u=distance(b, c, type=earthMovers)," +
"r=distance(q, canberra())," +
"s=distance(a, b, earthMovers()), " +
"t=distance(a, c, earthMovers())," +
"u=distance(b, c, earthMovers())," +
"w=transpose(matrix(a, b, c))," +
"x=distance(w, type=earthMovers)," +
"x=distance(w, earthMovers())," +
")";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
@ -2946,7 +2946,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
" c=knn(a, b, 2),"+
" d=getRowLabels(c),"+
" e=getAttributes(c)," +
" f=knn(a, b, 2, distance=manhattan)," +
" f=knn(a, b, 2, manhattan())," +
" g=getAttributes(f))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);