This commit is contained in:
Karl Wright 2018-04-26 02:59:23 -04:00
commit d53de2a385
13 changed files with 382 additions and 160 deletions

View File

@ -191,6 +191,9 @@ Bug Fixes
get an error that it's a member of that alias. This check now ensures the alias state is sync()'ed with ZK first. get an error that it's a member of that alias. This check now ensures the alias state is sync()'ed with ZK first.
(David Smiley) (David Smiley)
* SOLR-12275: wrong caching for {!filters} as well as for `filters` local param in {!parent} and {!child}
(David Smiley, Mikhail Khludnev)
Optimizations Optimizations
---------------------- ----------------------

View File

@ -17,8 +17,6 @@
package org.apache.solr.search.join; package org.apache.solr.search.join;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
import java.util.IdentityHashMap; import java.util.IdentityHashMap;
@ -36,31 +34,10 @@ import org.apache.solr.common.util.StrUtils;
import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.search.QParser; import org.apache.solr.search.QParser;
import org.apache.solr.search.QueryParsing; import org.apache.solr.search.QueryParsing;
import org.apache.solr.search.SolrConstantScoreQuery;
import org.apache.solr.search.SyntaxError; import org.apache.solr.search.SyntaxError;
public class FiltersQParser extends QParser { public class FiltersQParser extends QParser {
private static final class FiltersQuery extends SolrConstantScoreQuery {
private Set<Query> filters;
private FiltersQuery(SolrQueryRequest req, Set<Query> filters) throws IOException {
super(req.getSearcher().getDocSet(new ArrayList<>(filters)).getTopFilter());
this.filters = filters;
}
@Override
public boolean equals(Object other) {
return sameClassAs(other) && filters.equals(((FiltersQuery)other).filters);
}
@Override
public int hashCode() {
return 31 * classHash() + filters.hashCode();
}
}
protected String getFiltersParamName() { protected String getFiltersParamName() {
return "param"; return "param";
} }
@ -104,21 +81,14 @@ public class FiltersQParser extends QParser {
/** @return number of added clauses */ /** @return number of added clauses */
protected int addFilters(BooleanQuery.Builder builder, Map<Query,Occur> clauses) throws SyntaxError { protected int addFilters(BooleanQuery.Builder builder, Map<Query,Occur> clauses) throws SyntaxError {
Set<Query> filters = new HashSet<>(); int count=0;
for (Map.Entry<Query, Occur> clause: clauses.entrySet()) { for (Map.Entry<Query, Occur> clause: clauses.entrySet()) {
if (clause.getValue() == Occur.FILTER) { if (clause.getValue() == Occur.FILTER) {
filters.add(clause.getKey()); builder.add( clause.getKey(), Occur.FILTER);
count++;
} }
} }
if (!filters.isEmpty()) { return count;
try {
final SolrConstantScoreQuery intersQuery = new FiltersQuery(req, filters);
builder.add( intersQuery, Occur.FILTER);
} catch (IOException e) {
throw new SyntaxError("Exception occurs while parsing " + stringIncludingLocalParams, e);
}
}
return filters.size();
} }
protected void exclude(Map<Query,Occur> clauses) { protected void exclude(Map<Query,Occur> clauses) {

View File

@ -18,6 +18,7 @@ package org.apache.solr.search.join;
import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathConstants;
import java.io.IOException; import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -26,18 +27,26 @@ import java.util.ListIterator;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.search.join.ScoreMode;
import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrException;
import org.apache.solr.common.SolrException.ErrorCode; import org.apache.solr.common.SolrException.ErrorCode;
import org.apache.solr.metrics.MetricsMap; import org.apache.solr.metrics.MetricsMap;
import org.apache.solr.metrics.SolrMetricManager; import org.apache.solr.metrics.SolrMetricManager;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.search.QParser;
import org.apache.solr.search.SyntaxError;
import org.apache.solr.util.BaseTestHarness; import org.apache.solr.util.BaseTestHarness;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class BJQParserTest extends SolrTestCaseJ4 { public class BJQParserTest extends SolrTestCaseJ4 {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private static final String[] klm = new String[] {"k", "l", "m"}; private static final String[] klm = new String[] {"k", "l", "m"};
private static final List<String> xyz = Arrays.asList("x", "y", "z"); private static final List<String> xyz = Arrays.asList("x", "y", "z");
private static final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; private static final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"};
@ -389,8 +398,8 @@ public class BJQParserTest extends SolrTestCaseJ4 {
} }
private final static String elChild[] = new String[]{"//*[@numFound='1']", private final static String elChild[] = new String[]{"//*[@numFound='1']",
"//doc[" + "//doc[" +
"arr[@name=\"child_s\"]/str='l' and child::arr[@name=\"childparent_s\"]/str='e']"}; "arr[@name=\"child_s\"]/str='l' and arr[@name=\"childparent_s\"]/str='e']"};
@Test @Test
@ -441,10 +450,52 @@ public class BJQParserTest extends SolrTestCaseJ4 {
"child.fq", "{!tag=secondTag}child_s:l", // 6 ls remains "child.fq", "{!tag=secondTag}child_s:l", // 6 ls remains
"gchq", "{!tag=top}childparent_s:e"), "//*[@numFound='6']"); "gchq", "{!tag=top}childparent_s:e"), "//*[@numFound='6']");
assertQ(req("q", // top and filter are excluded, got zero result, but is it right? assertQ(req("q", // top and filter are excluded, got all results
"{!filters excludeTags=bot,secondTag,top v=$gchq}" , "{!filters excludeTags=bot,secondTag,top v=$gchq}" ,
"child.fq", "{!tag=secondTag}child_s:l", "child.fq", "{!tag=secondTag}child_s:l",
"gchq", "{!tag=top}childparent_s:e"), "//*[@numFound='42']"); "gchq", "{!tag=top}childparent_s:e"), "//*[@numFound='42']");
} }
/**
 * Checks that a query produced by the {@code {!filters}} parser keeps matching current index
 * content after a commit (presumably the SOLR-12275 regression - the temporary document id
 * mirrors that issue number).
 *
 * <p>Steps: run the filters request and expect the single {@code elChild} match; parse the same
 * {@code q} into a {@link Query} and confirm one hit; index one more matching document and
 * commit; then expect two hits both through a fresh request and through the previously parsed
 * query executed against the current searcher. The temporary document is removed afterwards on a
 * best-effort basis.
 */
@Test
public void testFiltersCache() throws SyntaxError, IOException {
  final String[] filterParams = new String[] {
      "q", "{!filters param=$child.fq v=$gchq}",
      "child.fq", "childparent_s:e",
      "child.fq", "child_s:l",
      "gchq", "child_s:[* TO *]"};

  assertQ(req(filterParams), elChild);

  // Parse the query once and hold on to it so it can be re-run after the index changes.
  final Query parsedQuery;
  final SolrQueryRequest parseRequest = req(filterParams);
  try {
    final String queryString = parseRequest.getParams().get("q");
    final QParser queryParser = QParser.getParser(queryString, null, parseRequest);
    parsedQuery = queryParser.getQuery();
    final TopDocs hitsBeforeUpdate = parseRequest.getSearcher().search(parsedQuery, 10);
    assertEquals(1, hitsBeforeUpdate.totalHits);
  } finally {
    parseRequest.close();
  }

  // Add a second document that satisfies both filters and make it searchable.
  assertU(adoc("id", "12275",
      "child_s", "l", "childparent_s", "e"));
  assertU(commit());

  try {
    // original author's note on this assertion: cache=false
    assertQ("here we rely on autowarming for cathing cache leak",
        req(filterParams), "//*[@numFound='2']");

    final SolrQueryRequest currentRequest = req();
    try {
      final TopDocs hitsAfterUpdate = currentRequest.getSearcher().search(parsedQuery, 10);
      assertEquals("expecting new doc is visible to old query", 2, hitsAfterUpdate.totalHits);
    } finally {
      currentRequest.close();
    }
  } finally {
    // Compensation: remove the temporary document; a failure here is logged, not rethrown,
    // so it does not mask the outcome of the assertions above.
    try {
      assertU(delI("12275"));
      assertU(commit());
    } catch (Throwable cleanupFailure) {
      log.error("ignoring exception occuring in compensation routine", cleanupFailure);
    }
  }
}
} }

View File

@ -237,6 +237,11 @@ public class Lang {
.withFunctionName("memset", MemsetEvaluator.class) .withFunctionName("memset", MemsetEvaluator.class)
.withFunctionName("fft", FFTEvaluator.class) .withFunctionName("fft", FFTEvaluator.class)
.withFunctionName("ifft", IFFTEvaluator.class) .withFunctionName("ifft", IFFTEvaluator.class)
.withFunctionName("manhattan", ManhattanEvaluator.class)
.withFunctionName("canberra", CanberraEvaluator.class)
.withFunctionName("earthMovers", EarthMoversEvaluator.class)
.withFunctionName("euclidean", EuclideanEvaluator.class)
.withFunctionName("chebyshev", ChebyshevEvaluator.class)
// Boolean Stream Evaluators // Boolean Stream Evaluators

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a Commons Math {@link CanberraDistance} measure.
 *
 * <p>{@link #evaluate(Tuple)} returns a newly created measure and does not read the tuple;
 * {@link #doWork(Object...)} is not a supported entry point and always throws.
 */
public class CanberraEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator by delegating to the superclass.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public CanberraEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator by delegating to the superclass, forwarding the ignored-parameter list.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @param ignoredNamedParameters named parameters forwarded to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public CanberraEvaluator(StreamExpression expression,
                           StreamFactory factory,
                           List<String> ignoredNamedParameters) throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Produces the distance measure; the supplied tuple is not consulted.
   *
   * @param tuple unused
   * @return a new {@link CanberraDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final CanberraDistance measure = new CanberraDistance();
    return measure;
  }

  /**
   * Unsupported for this evaluator.
   *
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    // All work happens in evaluate(Tuple); reaching this method is treated as an error.
    throw new IOException("This call should never occur");
  }
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.ChebyshevDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a Commons Math {@link ChebyshevDistance} measure.
 *
 * <p>{@link #evaluate(Tuple)} returns a newly created measure and does not read the tuple;
 * {@link #doWork(Object...)} is not a supported entry point and always throws.
 */
public class ChebyshevEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator by delegating to the superclass.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public ChebyshevEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator by delegating to the superclass, forwarding the ignored-parameter list.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @param ignoredNamedParameters named parameters forwarded to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public ChebyshevEvaluator(StreamExpression expression,
                            StreamFactory factory,
                            List<String> ignoredNamedParameters) throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Produces the distance measure; the supplied tuple is not consulted.
   *
   * @param tuple unused
   * @return a new {@link ChebyshevDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final ChebyshevDistance measure = new ChebyshevDistance();
    return measure;
  }

  /**
   * Unsupported for this evaluator.
   *
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    // All work happens in evaluate(Tuple); reaching this method is treated as an error.
    throw new IOException("This call should never occur");
  }
}

View File

@ -22,14 +22,9 @@ import java.util.List;
import java.util.Locale; import java.util.Locale;
import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
import org.apache.commons.math3.ml.distance.EuclideanDistance; import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
@ -40,94 +35,75 @@ public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyV
public DistanceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ public DistanceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory); super(expression, factory);
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
if(namedParams.size() > 0) {
if (namedParams.size() > 1) {
throw new IOException("distance function expects only one named parameter 'type'.");
}
StreamExpressionNamedParameter namedParameter = namedParams.get(0);
String name = namedParameter.getName();
if (!name.equalsIgnoreCase("type")) {
throw new IOException("distance function expects only one named parameter 'type'.");
}
String typeParam = namedParameter.getParameter().toString().trim();
this.type= DistanceType.valueOf(typeParam);
} else {
this.type = DistanceType.euclidean;
}
} }
@Override @Override
public Object doWork(Object ... values) throws IOException{ public Object doWork(Object ... values) throws IOException{
if(values.length == 2) { if(values.length == 1) {
if (values[0] instanceof Matrix) {
Matrix matrix = (Matrix) values[0];
EuclideanDistance euclideanDistance = new EuclideanDistance();
return distance(euclideanDistance, matrix);
} else {
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
}
} else if(values.length == 2) {
Object first = values[0]; Object first = values[0];
Object second = values[1]; Object second = values[1];
if (null == first) { if (null == first) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory))); throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory)));
} }
if (null == second) { if (null == second) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory))); throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory)));
} }
if(first instanceof Matrix) {
Matrix matrix = (Matrix) first;
DistanceMeasure distanceMeasure = (DistanceMeasure)second;
return distance(distanceMeasure, matrix);
} else {
if (!(first instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if (!(second instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
}
DistanceMeasure distanceMeasure = new EuclideanDistance();
return distanceMeasure.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
}
} else if (values.length == 3) {
Object first = values[0];
Object second = values[1];
DistanceMeasure distanceMeasure = (DistanceMeasure)values[2];
if (null == first) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory)));
}
if (null == second) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory)));
}
if (!(first instanceof List<?>)) { if (!(first instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName())); throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
} }
if (!(second instanceof List<?>)) { if (!(second instanceof List<?>)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName())); throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
} }
if (type.equals(DistanceType.euclidean)) { return distanceMeasure.compute(
EuclideanDistance euclideanDistance = new EuclideanDistance(); ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
return euclideanDistance.compute( ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(), );
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
} else if (type.equals(DistanceType.manhattan)) {
ManhattanDistance manhattanDistance = new ManhattanDistance();
return manhattanDistance.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
} else if (type.equals(DistanceType.canberra)) {
CanberraDistance canberraDistance = new CanberraDistance();
return canberraDistance.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
} else if (type.equals(DistanceType.earthMovers)) {
EarthMoversDistance earthMoversDistance = new EarthMoversDistance();
return earthMoversDistance.compute(
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
} else {
return null;
}
} else if(values.length == 1) {
if(values[0] instanceof Matrix) {
Matrix matrix = (Matrix)values[0];
if (type.equals(DistanceType.euclidean)) {
EuclideanDistance euclideanDistance = new EuclideanDistance();
return distance(euclideanDistance, matrix);
} else if (type.equals(DistanceType.canberra)) {
CanberraDistance canberraDistance = new CanberraDistance();
return distance(canberraDistance, matrix);
} else if (type.equals(DistanceType.manhattan)) {
ManhattanDistance manhattanDistance = new ManhattanDistance();
return distance(manhattanDistance, matrix);
} else if (type.equals(DistanceType.earthMovers)) {
EarthMoversDistance earthMoversDistance = new EarthMoversDistance();
return distance(earthMoversDistance, matrix);
} else {
return null;
}
} else {
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
}
} else { } else {
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters."); throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
} }

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a Commons Math {@link EarthMoversDistance} measure.
 *
 * <p>{@link #evaluate(Tuple)} returns a newly created measure and does not read the tuple;
 * {@link #doWork(Object...)} is not a supported entry point and always throws.
 */
public class EarthMoversEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator by delegating to the superclass.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public EarthMoversEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator by delegating to the superclass, forwarding the ignored-parameter list.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @param ignoredNamedParameters named parameters forwarded to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public EarthMoversEvaluator(StreamExpression expression,
                              StreamFactory factory,
                              List<String> ignoredNamedParameters) throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Produces the distance measure; the supplied tuple is not consulted.
   *
   * @param tuple unused
   * @return a new {@link EarthMoversDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final EarthMoversDistance measure = new EarthMoversDistance();
    return measure;
  }

  /**
   * Unsupported for this evaluator.
   *
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    // All work happens in evaluate(Tuple); reaching this method is treated as an error.
    throw new IOException("This call should never occur");
  }
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a Commons Math {@link EuclideanDistance} measure.
 *
 * <p>{@link #evaluate(Tuple)} returns a newly created measure and does not read the tuple;
 * {@link #doWork(Object...)} is not a supported entry point and always throws.
 */
public class EuclideanEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator by delegating to the superclass.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public EuclideanEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator by delegating to the superclass, forwarding the ignored-parameter list.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @param ignoredNamedParameters named parameters forwarded to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public EuclideanEvaluator(StreamExpression expression,
                            StreamFactory factory,
                            List<String> ignoredNamedParameters) throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Produces the distance measure; the supplied tuple is not consulted.
   *
   * @param tuple unused
   * @return a new {@link EuclideanDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final EuclideanDistance measure = new EuclideanDistance();
    return measure;
  }

  /**
   * Unsupported for this evaluator.
   *
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    // All work happens in evaluate(Tuple); reaching this method is treated as an error.
    throw new IOException("This call should never occur");
  }
}

View File

@ -22,52 +22,16 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.TreeSet; import java.util.TreeSet;
import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
import org.apache.commons.math3.ml.distance.EuclideanDistance; import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker { public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L; protected static final long serialVersionUID = 1L;
private DistanceMeasure distanceMeasure;
public KnnEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ public KnnEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory); super(expression, factory);
DistanceEvaluator.DistanceType type = null;
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
if(namedParams.size() > 0) {
if (namedParams.size() > 1) {
throw new IOException("distance function expects only one named parameter 'distance'.");
}
StreamExpressionNamedParameter namedParameter = namedParams.get(0);
String name = namedParameter.getName();
if (!name.equalsIgnoreCase("distance")) {
throw new IOException("distance function expects only one named parameter 'distance'.");
}
String typeParam = namedParameter.getParameter().toString().trim();
type= DistanceEvaluator.DistanceType.valueOf(typeParam);
} else {
type = DistanceEvaluator.DistanceType.euclidean;
}
if (type.equals(DistanceEvaluator.DistanceType.euclidean)) {
distanceMeasure = new EuclideanDistance();
} else if (type.equals(DistanceEvaluator.DistanceType.manhattan)) {
distanceMeasure = new ManhattanDistance();
} else if (type.equals(DistanceEvaluator.DistanceType.canberra)) {
distanceMeasure = new CanberraDistance();
} else if (type.equals(DistanceEvaluator.DistanceType.earthMovers)) {
distanceMeasure = new EarthMoversDistance();
}
} }
@Override @Override
@ -105,6 +69,14 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
double[][] data = matrix.getData(); double[][] data = matrix.getData();
DistanceMeasure distanceMeasure = null;
if(values.length == 4) {
distanceMeasure = (DistanceMeasure)values[3];
} else {
distanceMeasure = new EuclideanDistance();
}
TreeSet<Neighbor> neighbors = new TreeSet(); TreeSet<Neighbor> neighbors = new TreeSet();
for(int i=0; i<data.length; i++) { for(int i=0; i<data.length; i++) {
double distance = distanceMeasure.compute(vec, data[i]); double distance = distanceMeasure.compute(vec, data[i]);
@ -165,6 +137,5 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
return this.distance.compareTo(neighbor.getDistance()); return this.distance.compareTo(neighbor.getDistance());
} }
} }
} }

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
/**
 * Evaluator whose result is a Commons Math {@link ManhattanDistance} measure.
 *
 * <p>{@link #evaluate(Tuple)} returns a newly created measure and does not read the tuple;
 * {@link #doWork(Object...)} is not a supported entry point and always throws.
 */
public class ManhattanEvaluator extends RecursiveEvaluator {

  protected static final long serialVersionUID = 1L;

  /**
   * Creates the evaluator by delegating to the superclass.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public ManhattanEvaluator(StreamExpression expression, StreamFactory factory)
      throws IOException {
    super(expression, factory);
  }

  /**
   * Creates the evaluator by delegating to the superclass, forwarding the ignored-parameter list.
   *
   * @param expression the stream expression handed to the superclass constructor
   * @param factory the stream factory handed to the superclass constructor
   * @param ignoredNamedParameters named parameters forwarded to the superclass constructor
   * @throws IOException if the superclass constructor fails
   */
  public ManhattanEvaluator(StreamExpression expression,
                            StreamFactory factory,
                            List<String> ignoredNamedParameters) throws IOException {
    super(expression, factory, ignoredNamedParameters);
  }

  /**
   * Produces the distance measure; the supplied tuple is not consulted.
   *
   * @param tuple unused
   * @return a new {@link ManhattanDistance}
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    final ManhattanDistance measure = new ManhattanDistance();
    return measure;
  }

  /**
   * Unsupported for this evaluator.
   *
   * @throws IOException always
   */
  @Override
  public Object doWork(Object... values) throws IOException {
    // All work happens in evaluate(Tuple); reaching this method is treated as an error.
    throw new IOException("This call should never occur");
  }
}

View File

@ -68,7 +68,8 @@ public class TestLang extends LuceneTestCase {
TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME, TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME,
TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow", TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
"mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt", "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
"cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft"}; "cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan",
"earthMovers", "canberra", "chebyshev"};
@Test @Test
public void testLang() { public void testLang() {

View File

@ -481,21 +481,21 @@ public class MathExpressionTest extends SolrCloudTestCase {
"f=distance(b, c)," + "f=distance(b, c)," +
"g=transpose(matrix(a, b, c))," + "g=transpose(matrix(a, b, c))," +
"h=distance(g)," + "h=distance(g)," +
"i=distance(a, b, type=manhattan), " + "i=distance(a, b, manhattan()), " +
"j=distance(a, c, type=manhattan)," + "j=distance(a, c, manhattan())," +
"k=distance(b, c, type=manhattan)," + "k=distance(b, c, manhattan())," +
"l=transpose(matrix(a, b, c))," + "l=transpose(matrix(a, b, c))," +
"m=distance(l, type=manhattan)," + "m=distance(l, manhattan())," +
"n=distance(a, b, type=canberra), " + "n=distance(a, b, canberra()), " +
"o=distance(a, c, type=canberra)," + "o=distance(a, c, canberra())," +
"p=distance(b, c, type=canberra)," + "p=distance(b, c, canberra())," +
"q=transpose(matrix(a, b, c))," + "q=transpose(matrix(a, b, c))," +
"r=distance(q, type=canberra)," + "r=distance(q, canberra())," +
"s=distance(a, b, type=earthMovers), " + "s=distance(a, b, earthMovers()), " +
"t=distance(a, c, type=earthMovers)," + "t=distance(a, c, earthMovers())," +
"u=distance(b, c, type=earthMovers)," + "u=distance(b, c, earthMovers())," +
"w=transpose(matrix(a, b, c))," + "w=transpose(matrix(a, b, c))," +
"x=distance(w, type=earthMovers)," + "x=distance(w, earthMovers())," +
")"; ")";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
@ -2946,7 +2946,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
" c=knn(a, b, 2),"+ " c=knn(a, b, 2),"+
" d=getRowLabels(c),"+ " d=getRowLabels(c),"+
" e=getAttributes(c)," + " e=getAttributes(c)," +
" f=knn(a, b, 2, distance=manhattan)," + " f=knn(a, b, 2, manhattan())," +
" g=getAttributes(f))"; " g=getAttributes(f))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr); paramsLoc.set("expr", cexpr);