SOLR-11689: Add l1norm, l2norm and linfnorm Stream Evaluators

This commit is contained in:
Joel Bernstein 2018-02-16 19:06:58 -05:00
parent cb88bdbee2
commit 6c0f9ac8c7
7 changed files with 226 additions and 7 deletions

View File

@ -313,6 +313,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("fuzzyKmeans", FuzzyKmeansEvaluator.class)
.withFunctionName("getMembershipMatrix", GetMembershipMatrixEvaluator.class)
.withFunctionName("multiKmeans", MultiKmeansEvaluator.class)
.withFunctionName("l2norm", NormEvaluator.class)
.withFunctionName("l1norm", L1NormEvaluator.class)
.withFunctionName("linfnorm", LInfNormEvaluator.class)
// Boolean Stream Evaluators

View File

@ -19,6 +19,7 @@ package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -40,7 +41,7 @@ public class AbsoluteValueEvaluator extends RecursiveNumericEvaluator implements
return null;
}
else if(value instanceof List){
return ((List<?>)value).stream().map(innerValue -> doWork(innerValue));
return ((List<?>)value).stream().map(innerValue -> doWork(innerValue)).collect(Collectors.toList());
}
else{
return Math.abs(((Number)value).doubleValue());

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class L1NormEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
protected static final long serialVersionUID = 1L;
public L1NormEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(1 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 1 value but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value) throws IOException{
if(null == value){
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is null", constructingFactory.getFunctionName(getClass())));
}
else if(value instanceof List){
List<Number> c = (List<Number>) value;
double[] data = new double[c.size()];
for(int i=0; i< c.size(); i++) {
data[i] = c.get(i).doubleValue();
}
return new ArrayRealVector(data).getL1Norm();
}
else{
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is not a collection, instead a %s was found", constructingFactory.getFunctionName(getClass()), value.getClass().getSimpleName()));
}
}
}

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class LInfNormEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
protected static final long serialVersionUID = 1L;
public LInfNormEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(1 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 1 value but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value) throws IOException{
if(null == value){
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is null", constructingFactory.getFunctionName(getClass())));
}
else if(value instanceof List){
List<Number> c = (List<Number>) value;
double[] data = new double[c.size()];
for(int i=0; i< c.size(); i++) {
data[i] = c.get(i).doubleValue();
}
return new ArrayRealVector(data).getLInfNorm();
}
else{
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is not a collection, instead a %s was found", constructingFactory.getFunctionName(getClass()), value.getClass().getSimpleName()));
}
}
}

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class NormEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
protected static final long serialVersionUID = 1L;
public NormEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
if(1 != containedEvaluators.size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 1 value but found %d",expression,containedEvaluators.size()));
}
}
@Override
public Object doWork(Object value) throws IOException{
if(null == value){
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is null", constructingFactory.getFunctionName(getClass())));
}
else if(value instanceof List){
List<Number> c = (List<Number>) value;
double[] data = new double[c.size()];
for(int i=0; i< c.size(); i++) {
data[i] = c.get(i).doubleValue();
}
return new ArrayRealVector(data).getNorm();
}
else{
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is not a collection, instead a %s was found", constructingFactory.getFunctionName(getClass()), value.getClass().getSimpleName()));
}
}
}

View File

@ -17,8 +17,9 @@
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Locale;
import java.util.List;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@ -40,10 +41,21 @@ public class PowerEvaluator extends RecursiveNumericEvaluator implements TwoValu
if(null == first || null == second){
return null;
}
BigDecimal value = (BigDecimal)first;
BigDecimal exponent = (BigDecimal)second;
return Math.pow(value.doubleValue(), exponent.doubleValue());
if(first instanceof Number) {
Number value = (Number) first;
Number exponent = (Number) second;
return Math.pow(value.doubleValue(), exponent.doubleValue());
} else {
List<Number> values = (List<Number>) first;
Number exponent = (Number) second;
List<Number> out = new ArrayList(values.size());
for(Number value : values) {
out.add(Math.pow(value.doubleValue(), exponent.doubleValue()));
}
return out;
}
}
}

View File

@ -8792,6 +8792,41 @@ public class StreamExpressionTest extends SolrCloudTestCase {
}
@Test
public void testNorms() throws Exception {
String cexpr = "let(echo=true, " +
" a=array(1,2,3,4,5,6), " +
" b=l1norm(a), " +
" c=l2norm(a), " +
" d=linfnorm(a), " +
" e=sqrt(add(pow(a, 2)))," +
" f=add(abs(a)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
TupleStream solrStream = new SolrStream(url, paramsLoc);
StreamContext context = new StreamContext();
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
Number l1norm = (Number)tuples.get(0).get("b");
assertEquals(l1norm.doubleValue(), 21.0D, 0.0D);
Number norm = (Number)tuples.get(0).get("c");
assertEquals(norm.doubleValue(), 9.5393920141695, 0.0001D);
Number inorm = (Number)tuples.get(0).get("d");
assertEquals(inorm.doubleValue(), 6.0, 0.0);
Number norm2 = (Number)tuples.get(0).get("e");
assertEquals(norm.doubleValue(), norm2.doubleValue(), 0.0);
Number l1norm2 = (Number)tuples.get(0).get("f");
assertEquals(l1norm.doubleValue(), l1norm2.doubleValue(), 0.0);
}
@Test
public void testScale() throws Exception {
UpdateRequest updateRequest = new UpdateRequest();