From 5e2ef5eb73d23cd98af2ebec5cc14730d19c4ca4 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Tue, 16 Jan 2018 19:19:45 -0500 Subject: [PATCH] SOLR-11736: Rename knn Streaming Expression to knnSearch and add new knn Stream Evaluator --- .../apache/solr/handler/StreamHandler.java | 4 +- .../solrj/io/eval/GetAttributesEvaluator.java | 42 +++++ .../client/solrj/io/eval/KnnEvaluator.java | 170 ++++++++++++++++++ .../solrj/io/stream/StreamExpressionTest.java | 73 +++++++- 4 files changed, 282 insertions(+), 7 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributesEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index 206136c1a83..aa602860189 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -127,7 +127,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("topic", TopicStream.class) .withFunctionName("commit", CommitStream.class) .withFunctionName("random", RandomStream.class) - .withFunctionName("knn", KnnStream.class) + .withFunctionName("knnSearch", KnnStream.class) // decorator streams .withFunctionName("merge", MergeStream.class) @@ -305,6 +305,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("colAt", ColumnAtEvaluator.class) .withFunctionName("setColumnLabels", SetColumnLabelsEvaluator.class) .withFunctionName("setRowLabels", SetRowLabelsEvaluator.class) + .withFunctionName("knn", KnnEvaluator.class) + .withFunctionName("getAttributes", GetAttributesEvaluator.class) // Boolean Stream Evaluators diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributesEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributesEvaluator.java new file mode 100644 index 00000000000..b1c846e31ba --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributesEvaluator.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class GetAttributesEvaluator extends RecursiveObjectEvaluator implements OneValueWorker { + private static final long serialVersionUID = 1; + + public GetAttributesEvaluator(StreamExpression expression, StreamFactory factory) throws IOException { + super(expression, factory); + } + + @Override + public Object doWork(Object value) throws IOException { + if(!(value instanceof Attributes)){ + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for value, expecting an Attributes",toExpression(constructingFactory), value.getClass().getSimpleName())); + } else { + Attributes attributes = (Attributes)value; + return attributes.getAttributes(); + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java new file mode 100644 index 00000000000..665530eb45c --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.eval; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.TreeSet; + +import org.apache.commons.math3.ml.distance.CanberraDistance; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EarthMoversDistance; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.ml.distance.ManhattanDistance; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { + protected static final long serialVersionUID = 1L; + + private DistanceMeasure distanceMeasure; + + public KnnEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ + super(expression, factory); + + DistanceEvaluator.DistanceType type = null; + List namedParams = factory.getNamedOperands(expression); + if(namedParams.size() > 0) { + if (namedParams.size() > 1) { + throw new IOException("distance function expects only one named parameter 'distance'."); + } + + StreamExpressionNamedParameter namedParameter = namedParams.get(0); + String name = namedParameter.getName(); + if (!name.equalsIgnoreCase("distance")) { + throw new IOException("distance function expects only one named parameter 'distance'."); + } + + String typeParam = namedParameter.getParameter().toString().trim(); + type= DistanceEvaluator.DistanceType.valueOf(typeParam); + } else { + type = DistanceEvaluator.DistanceType.euclidean; + } + + if (type.equals(DistanceEvaluator.DistanceType.euclidean)) { + distanceMeasure = new EuclideanDistance(); + } else if (type.equals(DistanceEvaluator.DistanceType.manhattan)) { + distanceMeasure = new ManhattanDistance(); + } else if (type.equals(DistanceEvaluator.DistanceType.canberra)) { + distanceMeasure = new CanberraDistance(); + } else if (type.equals(DistanceEvaluator.DistanceType.earthMovers)) { + distanceMeasure = new EarthMoversDistance(); + } + + } + + @Override + public Object doWork(Object... values) throws IOException { + + if(values.length < 3) { + throw new IOException("knn expects three parameters a Matrix, numeric array and k"); + } + + Matrix matrix = null; + double[] vec = null; + int k = 0; + + if(values[0] instanceof Matrix) { + matrix = (Matrix)values[0]; + } else { + throw new IOException("The first parameter for knn should be a matrix."); + } + + if(values[1] instanceof List) { + List nums = (List)values[1]; + vec = new double[nums.size()]; + for(int i=0; i neighbors = new TreeSet(); + for(int i=0; i k) { + neighbors.pollLast(); + } + } + + double[][] out = new double[neighbors.size()][]; + List rowLabels = matrix.getRowLabels(); + List newRowLabels = new ArrayList(); + List distances = new ArrayList(); + int i=-1; + + while(neighbors.size() > 0) { + Neighbor neighbor = neighbors.pollFirst(); + int rowIndex = neighbor.getRow(); + + if(rowLabels != null) { + newRowLabels.add(rowLabels.get(rowIndex)); + } + + out[++i] = data[rowIndex]; + distances.add(neighbor.getDistance()); + } + + Matrix knn = new Matrix(out); + + if(rowLabels != null) { + knn.setRowLabels(newRowLabels); + } + + knn.setColumnLabels(matrix.getColumnLabels()); + knn.setAttribute("distances", distances); + return knn; + } + + public static class Neighbor implements Comparable { + + private Double distance; + private int row; + + public Neighbor(int row, double distance) { + this.distance = distance; + this.row = row; + } + + public int getRow() { + return this.row; + } + + public Double getDistance() { + return distance; + } + + public int compareTo(Neighbor neighbor) { + return this.distance.compareTo(neighbor.getDistance()); + } + } + +} + diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index 6f1e61f9aaf..1493562b571 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -933,7 +933,7 @@ public class StreamExpressionTest extends SolrCloudTestCase { } @Test - public void testKnnStream() throws Exception { + public void testKnnSearchStream() throws Exception { UpdateRequest update = new UpdateRequest(); update.add(id, "1", "a_t", "hello world have a very nice day blah"); @@ -947,7 +947,7 @@ public class StreamExpressionTest extends SolrCloudTestCase { try { context.setSolrClientCache(cache); ModifiableSolrParams sParams = new ModifiableSolrParams(StreamingTest.mapParams(CommonParams.QT, "/stream")); - sParams.add("expr", "knn(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"4\", fl=\"id, score\", mintf=\"1\")"); + sParams.add("expr", "knnSearch(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"4\", fl=\"id, score\", mintf=\"1\")"); JettySolrRunner jetty = cluster.getJettySolrRunner(0); SolrStream solrStream = new SolrStream(jetty.getBaseUrl().toString() + "/collection1", sParams); List tuples = getTuples(solrStream); @@ -955,26 +955,26 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertOrder(tuples, 2, 3, 4); sParams = new ModifiableSolrParams(StreamingTest.mapParams(CommonParams.QT, "/stream")); - sParams.add("expr", "knn(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", k=\"2\", fl=\"id, score\", mintf=\"1\")"); + sParams.add("expr", "knnSearch(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", k=\"2\", fl=\"id, score\", mintf=\"1\")"); solrStream = new SolrStream(jetty.getBaseUrl().toString() + "/collection1", sParams); tuples = getTuples(solrStream); assertTrue(tuples.size() == 2); assertOrder(tuples, 2, 3); sParams = new ModifiableSolrParams(StreamingTest.mapParams(CommonParams.QT, "/stream")); - sParams.add("expr", "knn(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"4\", fl=\"id, score\", mintf=\"1\", maxdf=\"0\")"); + sParams.add("expr", "knnSearch(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"4\", fl=\"id, score\", mintf=\"1\", maxdf=\"0\")"); solrStream = new SolrStream(jetty.getBaseUrl().toString() + "/collection1", sParams); tuples = getTuples(solrStream); assertTrue(tuples.size() == 0); sParams = new ModifiableSolrParams(StreamingTest.mapParams(CommonParams.QT, "/stream")); - sParams.add("expr", "knn(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"4\", fl=\"id, score\", mintf=\"1\", maxwl=\"1\")"); + sParams.add("expr", "knnSearch(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"4\", fl=\"id, score\", mintf=\"1\", maxwl=\"1\")"); solrStream = new SolrStream(jetty.getBaseUrl().toString() + "/collection1", sParams); tuples = getTuples(solrStream); assertTrue(tuples.size() == 0); sParams = new ModifiableSolrParams(StreamingTest.mapParams(CommonParams.QT, "/stream")); - sParams.add("expr", "knn(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"2\", fl=\"id, score\", mintf=\"1\", minwl=\"20\")"); + sParams.add("expr", "knnSearch(" + COLLECTIONORALIAS + ", id=\"1\", qf=\"a_t\", rows=\"2\", fl=\"id, score\", mintf=\"1\", minwl=\"20\")"); solrStream = new SolrStream(jetty.getBaseUrl().toString() + "/collection1", sParams); tuples = getTuples(solrStream); assertTrue(tuples.size() == 0); @@ -7734,6 +7734,67 @@ public class StreamExpressionTest extends SolrCloudTestCase { assertEquals(density.doubleValue(), 0.007852638121596995, .00001); } + @Test + public void testKnn() throws Exception { + String cexpr = "let(echo=true," + + " a=setRowLabels(matrix(array(1,1,1,0,0,0),"+ + " array(1,0,0,0,1,1),"+ + " array(0,0,0,1,1,1)), array(row1,row2,row3)),"+ + " b=array(0,0,0,1,1,1),"+ + " c=knn(a, b, 2),"+ + " d=getRowLabels(c),"+ + " e=getAttributes(c)," + + " f=knn(a, b, 2, distance=manhattan)," + + " g=getAttributes(f))"; + ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); + paramsLoc.set("expr", cexpr); + paramsLoc.set("qt", "/stream"); + String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS; + TupleStream solrStream = new SolrStream(url, paramsLoc); + StreamContext context = new StreamContext(); + solrStream.setStreamContext(context); + List tuples = getTuples(solrStream); + assertTrue(tuples.size() == 1); + + List> knnMatrix = (List>)tuples.get(0).get("c"); + assertEquals(knnMatrix.size(), 2); + + List row1 = knnMatrix.get(0); + assertEquals(row1.size(), 6); + assertEquals(row1.get(0).doubleValue(), 0.0, 0.0); + assertEquals(row1.get(1).doubleValue(), 0.0, 0.0); + assertEquals(row1.get(2).doubleValue(), 0.0, 0.0); + assertEquals(row1.get(3).doubleValue(), 1.0, 0.0); + assertEquals(row1.get(4).doubleValue(), 1.0, 0.0); + assertEquals(row1.get(5).doubleValue(), 1.0, 0.0); + + List row2 = knnMatrix.get(1); + assertEquals(row2.size(), 6); + + assertEquals(row2.get(0).doubleValue(), 1.0, 0.0); + assertEquals(row2.get(1).doubleValue(), 0.0, 0.0); + assertEquals(row2.get(2).doubleValue(), 0.0, 0.0); + assertEquals(row2.get(3).doubleValue(), 0.0, 0.0); + assertEquals(row2.get(4).doubleValue(), 1.0, 0.0); + assertEquals(row2.get(5).doubleValue(), 1.0, 0.0); + + Map atts = (Map)tuples.get(0).get("e"); + List dists = (List)atts.get("distances"); + assertEquals(dists.size(), 2); + assertEquals(dists.get(0).doubleValue(), 0.0, 0.0); + assertEquals(dists.get(1).doubleValue(), 1.4142135623730951, 0.0); + + List rowLabels = (List)tuples.get(0).get("d"); + assertEquals(rowLabels.size(), 2); + assertEquals(rowLabels.get(0), "row3"); + assertEquals(rowLabels.get(1), "row2"); + + atts = (Map)tuples.get(0).get("g"); + dists = (List)atts.get("distances"); + assertEquals(dists.size(), 2); + assertEquals(dists.get(0).doubleValue(), 0.0, 0.0); + assertEquals(dists.get(1).doubleValue(), 2.0, 0.0); + } @Test public void testIntegrate() throws Exception {