mirror of https://github.com/apache/lucene.git
SOLR-11689: Add l1norm, l2norm and linfnorm Stream Evaluators
This commit is contained in:
parent
cb88bdbee2
commit
6c0f9ac8c7
|
@ -313,6 +313,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
|||
.withFunctionName("fuzzyKmeans", FuzzyKmeansEvaluator.class)
|
||||
.withFunctionName("getMembershipMatrix", GetMembershipMatrixEvaluator.class)
|
||||
.withFunctionName("multiKmeans", MultiKmeansEvaluator.class)
|
||||
.withFunctionName("l2norm", NormEvaluator.class)
|
||||
.withFunctionName("l1norm", L1NormEvaluator.class)
|
||||
.withFunctionName("linfnorm", LInfNormEvaluator.class)
|
||||
|
||||
// Boolean Stream Evaluators
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.solr.client.solrj.io.eval;
|
|||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
@ -40,7 +41,7 @@ public class AbsoluteValueEvaluator extends RecursiveNumericEvaluator implements
|
|||
return null;
|
||||
}
|
||||
else if(value instanceof List){
|
||||
return ((List<?>)value).stream().map(innerValue -> doWork(innerValue));
|
||||
return ((List<?>)value).stream().map(innerValue -> doWork(innerValue)).collect(Collectors.toList());
|
||||
}
|
||||
else{
|
||||
return Math.abs(((Number)value).doubleValue());
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class L1NormEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public L1NormEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
if(1 != containedEvaluators.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 1 value but found %d",expression,containedEvaluators.size()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object value) throws IOException{
|
||||
if(null == value){
|
||||
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is null", constructingFactory.getFunctionName(getClass())));
|
||||
}
|
||||
else if(value instanceof List){
|
||||
List<Number> c = (List<Number>) value;
|
||||
double[] data = new double[c.size()];
|
||||
for(int i=0; i< c.size(); i++) {
|
||||
data[i] = c.get(i).doubleValue();
|
||||
}
|
||||
|
||||
return new ArrayRealVector(data).getL1Norm();
|
||||
}
|
||||
else{
|
||||
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is not a collection, instead a %s was found", constructingFactory.getFunctionName(getClass()), value.getClass().getSimpleName()));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class LInfNormEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public LInfNormEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
if(1 != containedEvaluators.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 1 value but found %d",expression,containedEvaluators.size()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object value) throws IOException{
|
||||
if(null == value){
|
||||
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is null", constructingFactory.getFunctionName(getClass())));
|
||||
}
|
||||
else if(value instanceof List){
|
||||
List<Number> c = (List<Number>) value;
|
||||
double[] data = new double[c.size()];
|
||||
for(int i=0; i< c.size(); i++) {
|
||||
data[i] = c.get(i).doubleValue();
|
||||
}
|
||||
|
||||
return new ArrayRealVector(data).getLInfNorm();
|
||||
}
|
||||
else{
|
||||
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is not a collection, instead a %s was found", constructingFactory.getFunctionName(getClass()), value.getClass().getSimpleName()));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class NormEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public NormEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
||||
if(1 != containedEvaluators.size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly 1 value but found %d",expression,containedEvaluators.size()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object value) throws IOException{
|
||||
if(null == value){
|
||||
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is null", constructingFactory.getFunctionName(getClass())));
|
||||
}
|
||||
else if(value instanceof List){
|
||||
List<Number> c = (List<Number>) value;
|
||||
double[] data = new double[c.size()];
|
||||
for(int i=0; i< c.size(); i++) {
|
||||
data[i] = c.get(i).doubleValue();
|
||||
}
|
||||
|
||||
return new ArrayRealVector(data).getNorm();
|
||||
}
|
||||
else{
|
||||
throw new IOException(String.format(Locale.ROOT, "Unable to find %s(...) because the value is not a collection, instead a %s was found", constructingFactory.getFunctionName(getClass()), value.getClass().getSimpleName()));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,8 +17,9 @@
|
|||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Locale;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
@ -40,10 +41,21 @@ public class PowerEvaluator extends RecursiveNumericEvaluator implements TwoValu
|
|||
if(null == first || null == second){
|
||||
return null;
|
||||
}
|
||||
|
||||
BigDecimal value = (BigDecimal)first;
|
||||
BigDecimal exponent = (BigDecimal)second;
|
||||
|
||||
return Math.pow(value.doubleValue(), exponent.doubleValue());
|
||||
|
||||
if(first instanceof Number) {
|
||||
Number value = (Number) first;
|
||||
Number exponent = (Number) second;
|
||||
return Math.pow(value.doubleValue(), exponent.doubleValue());
|
||||
} else {
|
||||
List<Number> values = (List<Number>) first;
|
||||
Number exponent = (Number) second;
|
||||
|
||||
List<Number> out = new ArrayList(values.size());
|
||||
for(Number value : values) {
|
||||
out.add(Math.pow(value.doubleValue(), exponent.doubleValue()));
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8792,6 +8792,41 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNorms() throws Exception {
|
||||
String cexpr = "let(echo=true, " +
|
||||
" a=array(1,2,3,4,5,6), " +
|
||||
" b=l1norm(a), " +
|
||||
" c=l2norm(a), " +
|
||||
" d=linfnorm(a), " +
|
||||
" e=sqrt(add(pow(a, 2)))," +
|
||||
" f=add(abs(a)))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Number l1norm = (Number)tuples.get(0).get("b");
|
||||
assertEquals(l1norm.doubleValue(), 21.0D, 0.0D);
|
||||
|
||||
Number norm = (Number)tuples.get(0).get("c");
|
||||
assertEquals(norm.doubleValue(), 9.5393920141695, 0.0001D);
|
||||
|
||||
Number inorm = (Number)tuples.get(0).get("d");
|
||||
assertEquals(inorm.doubleValue(), 6.0, 0.0);
|
||||
|
||||
Number norm2 = (Number)tuples.get(0).get("e");
|
||||
assertEquals(norm.doubleValue(), norm2.doubleValue(), 0.0);
|
||||
|
||||
Number l1norm2 = (Number)tuples.get(0).get("f");
|
||||
assertEquals(l1norm.doubleValue(), l1norm2.doubleValue(), 0.0);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testScale() throws Exception {
|
||||
UpdateRequest updateRequest = new UpdateRequest();
|
||||
|
|
Loading…
Reference in New Issue