mirror of https://github.com/apache/lucene.git
Merge branch 'master' of https://git-wip-us.apache.org/repos/asf/lucene-solr
This commit is contained in:
commit
d53de2a385
|
@ -191,6 +191,9 @@ Bug Fixes
|
|||
get an error that it's a member of that alias. This check now ensures the alias state is sync()'ed with ZK first.
|
||||
(David Smiley)
|
||||
|
||||
* SOLR-12275: wrong caching for {!filters} as well as for `filters` local param in {!parent} and {!child}
|
||||
(David Smiley, Mikhail Khludnev)
|
||||
|
||||
Optimizations
|
||||
----------------------
|
||||
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.solr.search.join;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.IdentityHashMap;
|
||||
|
@ -36,31 +34,10 @@ import org.apache.solr.common.util.StrUtils;
|
|||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.search.QParser;
|
||||
import org.apache.solr.search.QueryParsing;
|
||||
import org.apache.solr.search.SolrConstantScoreQuery;
|
||||
import org.apache.solr.search.SyntaxError;
|
||||
|
||||
public class FiltersQParser extends QParser {
|
||||
|
||||
private static final class FiltersQuery extends SolrConstantScoreQuery {
|
||||
|
||||
private Set<Query> filters;
|
||||
|
||||
private FiltersQuery(SolrQueryRequest req, Set<Query> filters) throws IOException {
|
||||
super(req.getSearcher().getDocSet(new ArrayList<>(filters)).getTopFilter());
|
||||
this.filters = filters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object other) {
|
||||
return sameClassAs(other) && filters.equals(((FiltersQuery)other).filters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + filters.hashCode();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Returns the name of the local param used for filter clauses; this implementation
 * returns the literal {@code "param"} (as in {@code {!filters param=$child.fq ...}}).
 * NOTE(review): protected and non-final, so presumably overridden by parsers that use a
 * different local-param name -- confirm against subclasses.
 */
protected String getFiltersParamName() {
|
||||
return "param";
|
||||
}
|
||||
|
@ -104,21 +81,14 @@ public class FiltersQParser extends QParser {
|
|||
|
||||
/** @return number of added clauses */
|
||||
protected int addFilters(BooleanQuery.Builder builder, Map<Query,Occur> clauses) throws SyntaxError {
|
||||
Set<Query> filters = new HashSet<>();
|
||||
int count=0;
|
||||
for (Map.Entry<Query, Occur> clause: clauses.entrySet()) {
|
||||
if (clause.getValue() == Occur.FILTER) {
|
||||
filters.add(clause.getKey());
|
||||
builder.add( clause.getKey(), Occur.FILTER);
|
||||
count++;
|
||||
}
|
||||
}
|
||||
if (!filters.isEmpty()) {
|
||||
try {
|
||||
final SolrConstantScoreQuery intersQuery = new FiltersQuery(req, filters);
|
||||
builder.add( intersQuery, Occur.FILTER);
|
||||
} catch (IOException e) {
|
||||
throw new SyntaxError("Exception occurs while parsing " + stringIncludingLocalParams, e);
|
||||
}
|
||||
}
|
||||
return filters.size();
|
||||
return count;
|
||||
}
|
||||
|
||||
protected void exclude(Map<Query,Occur> clauses) {
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.solr.search.join;
|
|||
|
||||
import javax.xml.xpath.XPathConstants;
|
||||
import java.io.IOException;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -26,18 +27,26 @@ import java.util.ListIterator;
|
|||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.join.ScoreMode;
|
||||
import org.apache.solr.SolrTestCaseJ4;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.SolrException.ErrorCode;
|
||||
import org.apache.solr.metrics.MetricsMap;
|
||||
import org.apache.solr.metrics.SolrMetricManager;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.search.QParser;
|
||||
import org.apache.solr.search.SyntaxError;
|
||||
import org.apache.solr.util.BaseTestHarness;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class BJQParserTest extends SolrTestCaseJ4 {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
|
||||
private static final String[] klm = new String[] {"k", "l", "m"};
|
||||
private static final List<String> xyz = Arrays.asList("x", "y", "z");
|
||||
private static final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"};
|
||||
|
@ -389,8 +398,8 @@ public class BJQParserTest extends SolrTestCaseJ4 {
|
|||
}
|
||||
|
||||
private final static String elChild[] = new String[]{"//*[@numFound='1']",
|
||||
"//doc[" +
|
||||
"arr[@name=\"child_s\"]/str='l' and child::arr[@name=\"childparent_s\"]/str='e']"};
|
||||
"//doc[" +
|
||||
"arr[@name=\"child_s\"]/str='l' and arr[@name=\"childparent_s\"]/str='e']"};
|
||||
|
||||
|
||||
@Test
|
||||
|
@ -441,10 +450,52 @@ public class BJQParserTest extends SolrTestCaseJ4 {
|
|||
"child.fq", "{!tag=secondTag}child_s:l", // 6 ls remains
|
||||
"gchq", "{!tag=top}childparent_s:e"), "//*[@numFound='6']");
|
||||
|
||||
assertQ(req("q", // top and filter are excluded, got zero result, but is it right?
|
||||
assertQ(req("q", // top and filter are excluded, got all results
|
||||
"{!filters excludeTags=bot,secondTag,top v=$gchq}" ,
|
||||
"child.fq", "{!tag=secondTag}child_s:l",
|
||||
"gchq", "{!tag=top}childparent_s:e"), "//*[@numFound='42']");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFiltersCache() throws SyntaxError, IOException {
|
||||
final String [] elFilterQuery = new String[] {"q", "{!filters param=$child.fq v=$gchq}",
|
||||
"child.fq", "childparent_s:e",
|
||||
"child.fq", "child_s:l",
|
||||
"gchq", "child_s:[* TO *]"};
|
||||
assertQ(req(elFilterQuery), elChild);
|
||||
final Query query;
|
||||
{
|
||||
final SolrQueryRequest req = req(elFilterQuery);
|
||||
try {
|
||||
QParser parser = QParser.getParser(req.getParams().get("q"), null, req);
|
||||
query = parser.getQuery();
|
||||
final TopDocs topDocs = req.getSearcher().search(query, 10);
|
||||
assertEquals(1, topDocs.totalHits);
|
||||
}finally {
|
||||
req.close();
|
||||
}
|
||||
}
|
||||
assertU(adoc("id", "12275",
|
||||
"child_s", "l", "childparent_s", "e"));
|
||||
assertU(commit());
|
||||
try {
|
||||
assertQ("here we rely on autowarming for cathing cache leak", //cache=false
|
||||
req(elFilterQuery), "//*[@numFound='2']");
|
||||
final SolrQueryRequest req = req();
|
||||
try {
|
||||
final TopDocs topDocs = req.getSearcher().search(query, 10);
|
||||
assertEquals("expecting new doc is visible to old query", 2, topDocs.totalHits);
|
||||
}finally {
|
||||
req.close();
|
||||
}
|
||||
}finally {
|
||||
try {
|
||||
assertU(delI("12275"));
|
||||
assertU(commit());
|
||||
} catch(Throwable t) {
|
||||
log.error("ignoring exception occuring in compensation routine", t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -237,6 +237,11 @@ public class Lang {
|
|||
.withFunctionName("memset", MemsetEvaluator.class)
|
||||
.withFunctionName("fft", FFTEvaluator.class)
|
||||
.withFunctionName("ifft", IFFTEvaluator.class)
|
||||
.withFunctionName("manhattan", ManhattanEvaluator.class)
|
||||
.withFunctionName("canberra", CanberraEvaluator.class)
|
||||
.withFunctionName("earthMovers", EarthMoversEvaluator.class)
|
||||
.withFunctionName("euclidean", EuclideanEvaluator.class)
|
||||
.withFunctionName("chebyshev", ChebyshevEvaluator.class)
|
||||
|
||||
// Boolean Stream Evaluators
|
||||
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.ml.distance.CanberraDistance;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class CanberraEvaluator extends RecursiveEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public CanberraEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
public CanberraEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{
|
||||
super(expression, factory, ignoredNamedParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
return new CanberraDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
// Nothing to do here
|
||||
throw new IOException("This call should never occur");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.ml.distance.ChebyshevDistance;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class ChebyshevEvaluator extends RecursiveEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public ChebyshevEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
public ChebyshevEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{
|
||||
super(expression, factory, ignoredNamedParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
return new ChebyshevDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
// Nothing to do here
|
||||
throw new IOException("This call should never occur");
|
||||
}
|
||||
}
|
|
@ -22,14 +22,9 @@ import java.util.List;
|
|||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.ml.distance.CanberraDistance;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math3.ml.distance.ManhattanDistance;
|
||||
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
|
||||
|
@ -40,94 +35,75 @@ public class DistanceEvaluator extends RecursiveObjectEvaluator implements ManyV
|
|||
|
||||
public DistanceEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
if(namedParams.size() > 0) {
|
||||
if (namedParams.size() > 1) {
|
||||
throw new IOException("distance function expects only one named parameter 'type'.");
|
||||
}
|
||||
|
||||
StreamExpressionNamedParameter namedParameter = namedParams.get(0);
|
||||
String name = namedParameter.getName();
|
||||
if (!name.equalsIgnoreCase("type")) {
|
||||
throw new IOException("distance function expects only one named parameter 'type'.");
|
||||
}
|
||||
|
||||
String typeParam = namedParameter.getParameter().toString().trim();
|
||||
this.type= DistanceType.valueOf(typeParam);
|
||||
} else {
|
||||
this.type = DistanceType.euclidean;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object ... values) throws IOException{
|
||||
|
||||
if(values.length == 2) {
|
||||
if(values.length == 1) {
|
||||
if (values[0] instanceof Matrix) {
|
||||
Matrix matrix = (Matrix) values[0];
|
||||
EuclideanDistance euclideanDistance = new EuclideanDistance();
|
||||
return distance(euclideanDistance, matrix);
|
||||
} else {
|
||||
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
|
||||
}
|
||||
} else if(values.length == 2) {
|
||||
Object first = values[0];
|
||||
Object second = values[1];
|
||||
|
||||
if (null == first) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
if (null == second) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
if(first instanceof Matrix) {
|
||||
Matrix matrix = (Matrix) first;
|
||||
DistanceMeasure distanceMeasure = (DistanceMeasure)second;
|
||||
return distance(distanceMeasure, matrix);
|
||||
} else {
|
||||
if (!(first instanceof List<?>)) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
if (!(second instanceof List<?>)) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
DistanceMeasure distanceMeasure = new EuclideanDistance();
|
||||
return distanceMeasure.compute(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
}
|
||||
} else if (values.length == 3) {
|
||||
Object first = values[0];
|
||||
Object second = values[1];
|
||||
DistanceMeasure distanceMeasure = (DistanceMeasure)values[2];
|
||||
|
||||
if (null == first) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the first value", toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
if (null == second) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the second value", toExpression(constructingFactory)));
|
||||
}
|
||||
|
||||
if (!(first instanceof List<?>)) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
if (!(second instanceof List<?>)) {
|
||||
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a list of numbers", toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
if (type.equals(DistanceType.euclidean)) {
|
||||
EuclideanDistance euclideanDistance = new EuclideanDistance();
|
||||
return euclideanDistance.compute(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
} else if (type.equals(DistanceType.manhattan)) {
|
||||
ManhattanDistance manhattanDistance = new ManhattanDistance();
|
||||
return manhattanDistance.compute(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
|
||||
} else if (type.equals(DistanceType.canberra)) {
|
||||
CanberraDistance canberraDistance = new CanberraDistance();
|
||||
return canberraDistance.compute(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
} else if (type.equals(DistanceType.earthMovers)) {
|
||||
EarthMoversDistance earthMoversDistance = new EarthMoversDistance();
|
||||
return earthMoversDistance.compute(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
} else if(values.length == 1) {
|
||||
if(values[0] instanceof Matrix) {
|
||||
Matrix matrix = (Matrix)values[0];
|
||||
if (type.equals(DistanceType.euclidean)) {
|
||||
EuclideanDistance euclideanDistance = new EuclideanDistance();
|
||||
return distance(euclideanDistance, matrix);
|
||||
} else if (type.equals(DistanceType.canberra)) {
|
||||
CanberraDistance canberraDistance = new CanberraDistance();
|
||||
return distance(canberraDistance, matrix);
|
||||
} else if (type.equals(DistanceType.manhattan)) {
|
||||
ManhattanDistance manhattanDistance = new ManhattanDistance();
|
||||
return distance(manhattanDistance, matrix);
|
||||
} else if (type.equals(DistanceType.earthMovers)) {
|
||||
EarthMoversDistance earthMoversDistance = new EarthMoversDistance();
|
||||
return distance(earthMoversDistance, matrix);
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
} else {
|
||||
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
|
||||
}
|
||||
return distanceMeasure.compute(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
} else {
|
||||
throw new IOException("distance function operates on either two numeric arrays or a single matrix as parameters.");
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EarthMoversEvaluator extends RecursiveEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EarthMoversEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
public EarthMoversEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{
|
||||
super(expression, factory, ignoredNamedParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
return new EarthMoversDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
// Nothing to do here
|
||||
throw new IOException("This call should never occur");
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EuclideanEvaluator extends RecursiveEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EuclideanEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
public EuclideanEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{
|
||||
super(expression, factory, ignoredNamedParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
return new EuclideanDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
// Nothing to do here
|
||||
throw new IOException("This call should never occur");
|
||||
}
|
||||
}
|
|
@ -22,52 +22,16 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import org.apache.commons.math3.ml.distance.CanberraDistance;
|
||||
import org.apache.commons.math3.ml.distance.DistanceMeasure;
|
||||
import org.apache.commons.math3.ml.distance.EarthMoversDistance;
|
||||
import org.apache.commons.math3.ml.distance.EuclideanDistance;
|
||||
import org.apache.commons.math3.ml.distance.ManhattanDistance;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
private DistanceMeasure distanceMeasure;
|
||||
|
||||
public KnnEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
DistanceEvaluator.DistanceType type = null;
|
||||
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
|
||||
if(namedParams.size() > 0) {
|
||||
if (namedParams.size() > 1) {
|
||||
throw new IOException("distance function expects only one named parameter 'distance'.");
|
||||
}
|
||||
|
||||
StreamExpressionNamedParameter namedParameter = namedParams.get(0);
|
||||
String name = namedParameter.getName();
|
||||
if (!name.equalsIgnoreCase("distance")) {
|
||||
throw new IOException("distance function expects only one named parameter 'distance'.");
|
||||
}
|
||||
|
||||
String typeParam = namedParameter.getParameter().toString().trim();
|
||||
type= DistanceEvaluator.DistanceType.valueOf(typeParam);
|
||||
} else {
|
||||
type = DistanceEvaluator.DistanceType.euclidean;
|
||||
}
|
||||
|
||||
if (type.equals(DistanceEvaluator.DistanceType.euclidean)) {
|
||||
distanceMeasure = new EuclideanDistance();
|
||||
} else if (type.equals(DistanceEvaluator.DistanceType.manhattan)) {
|
||||
distanceMeasure = new ManhattanDistance();
|
||||
} else if (type.equals(DistanceEvaluator.DistanceType.canberra)) {
|
||||
distanceMeasure = new CanberraDistance();
|
||||
} else if (type.equals(DistanceEvaluator.DistanceType.earthMovers)) {
|
||||
distanceMeasure = new EarthMoversDistance();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -105,6 +69,14 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
|
|||
|
||||
double[][] data = matrix.getData();
|
||||
|
||||
DistanceMeasure distanceMeasure = null;
|
||||
|
||||
if(values.length == 4) {
|
||||
distanceMeasure = (DistanceMeasure)values[3];
|
||||
} else {
|
||||
distanceMeasure = new EuclideanDistance();
|
||||
}
|
||||
|
||||
TreeSet<Neighbor> neighbors = new TreeSet();
|
||||
for(int i=0; i<data.length; i++) {
|
||||
double distance = distanceMeasure.compute(vec, data[i]);
|
||||
|
@ -165,6 +137,5 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
|
|||
return this.distance.compareTo(neighbor.getDistance());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.math3.ml.distance.ManhattanDistance;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class ManhattanEvaluator extends RecursiveEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public ManhattanEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
public ManhattanEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException{
|
||||
super(expression, factory, ignoredNamedParameters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
return new ManhattanDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object... values) throws IOException {
|
||||
// Nothing to do here
|
||||
throw new IOException("This call should never occur");
|
||||
}
|
||||
}
|
|
@ -68,7 +68,8 @@ public class TestLang extends LuceneTestCase {
|
|||
TemporalEvaluatorEpoch.FUNCTION_NAME, TemporalEvaluatorWeek.FUNCTION_NAME, TemporalEvaluatorQuarter.FUNCTION_NAME,
|
||||
TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
|
||||
"mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
|
||||
"cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft"};
|
||||
"cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan",
|
||||
"earthMovers", "canberra", "chebyshev"};
|
||||
|
||||
@Test
|
||||
public void testLang() {
|
||||
|
|
|
@ -481,21 +481,21 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
"f=distance(b, c)," +
|
||||
"g=transpose(matrix(a, b, c))," +
|
||||
"h=distance(g)," +
|
||||
"i=distance(a, b, type=manhattan), " +
|
||||
"j=distance(a, c, type=manhattan)," +
|
||||
"k=distance(b, c, type=manhattan)," +
|
||||
"i=distance(a, b, manhattan()), " +
|
||||
"j=distance(a, c, manhattan())," +
|
||||
"k=distance(b, c, manhattan())," +
|
||||
"l=transpose(matrix(a, b, c))," +
|
||||
"m=distance(l, type=manhattan)," +
|
||||
"n=distance(a, b, type=canberra), " +
|
||||
"o=distance(a, c, type=canberra)," +
|
||||
"p=distance(b, c, type=canberra)," +
|
||||
"m=distance(l, manhattan())," +
|
||||
"n=distance(a, b, canberra()), " +
|
||||
"o=distance(a, c, canberra())," +
|
||||
"p=distance(b, c, canberra())," +
|
||||
"q=transpose(matrix(a, b, c))," +
|
||||
"r=distance(q, type=canberra)," +
|
||||
"s=distance(a, b, type=earthMovers), " +
|
||||
"t=distance(a, c, type=earthMovers)," +
|
||||
"u=distance(b, c, type=earthMovers)," +
|
||||
"r=distance(q, canberra())," +
|
||||
"s=distance(a, b, earthMovers()), " +
|
||||
"t=distance(a, c, earthMovers())," +
|
||||
"u=distance(b, c, earthMovers())," +
|
||||
"w=transpose(matrix(a, b, c))," +
|
||||
"x=distance(w, type=earthMovers)," +
|
||||
"x=distance(w, earthMovers())," +
|
||||
")";
|
||||
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
|
@ -2946,7 +2946,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
|
|||
" c=knn(a, b, 2),"+
|
||||
" d=getRowLabels(c),"+
|
||||
" e=getAttributes(c)," +
|
||||
" f=knn(a, b, 2, distance=manhattan)," +
|
||||
" f=knn(a, b, 2, manhattan())," +
|
||||
" g=getAttributes(f))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
|
Loading…
Reference in New Issue