mirror of https://github.com/apache/lucene.git
SOLR-11321: Add ebeAdd, ebeSubtract, ebeDivide, ebeMultiply, dotProduct and cosineSimilarity Stream Evaluators
This commit is contained in:
parent
e782082e71
commit
3423ae4b92
|
@ -33,84 +33,7 @@ import org.apache.solr.client.solrj.io.ModelCache;
|
||||||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||||
import org.apache.solr.client.solrj.io.Tuple;
|
import org.apache.solr.client.solrj.io.Tuple;
|
||||||
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
||||||
import org.apache.solr.client.solrj.io.eval.AbsoluteValueEvaluator;
|
import org.apache.solr.client.solrj.io.eval.*;
|
||||||
import org.apache.solr.client.solrj.io.eval.AddEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.AndEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.AnovaEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.AppendEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ArcCosineEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ArcSineEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ArcTangentEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ArrayEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.AscEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CeilingEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CoalesceEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ColumnEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ConversionEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ConvolutionEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CopyOfEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CopyOfRangeEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CorrelationEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CosineEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CovarianceEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CubedRootEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.CumulativeProbabilityEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.DescribeEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.DivideEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.EmpiricalDistributionEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.EqualToEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.EuclideanDistanceEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ExclusiveOrEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.FindDelayEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.FloorEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.GreaterThanEqualToEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.GreaterThanEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.HistogramEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.HyperbolicCosineEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.HyperbolicSineEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.HyperbolicTangentEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.IfThenElseEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.KolmogorovSmirnovEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.LengthEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.LessThanEqualToEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.LessThanEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ModuloEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.MovingAverageEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.MultiplyEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.NaturalLogEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.NormalDistributionEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.NormalizeEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.NotEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.OrEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.PercentileEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.PowerEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.PredictEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.RankEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.RawValueEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.RegressionEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ResidualsEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ReverseEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.RoundEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.SampleEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.ScaleEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.SequenceEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.SineEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.SquareRootEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.SubtractEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TangentEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDay;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDayOfQuarter;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDayOfYear;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorEpoch;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorHour;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorMinute;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorMonth;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorQuarter;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorSecond;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorWeek;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorYear;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.UniformDistributionEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.eval.UuidEvaluator;
|
|
||||||
import org.apache.solr.client.solrj.io.graph.GatherNodesStream;
|
import org.apache.solr.client.solrj.io.graph.GatherNodesStream;
|
||||||
import org.apache.solr.client.solrj.io.graph.ShortestPathStream;
|
import org.apache.solr.client.solrj.io.graph.ShortestPathStream;
|
||||||
import org.apache.solr.client.solrj.io.ops.ConcatOperation;
|
import org.apache.solr.client.solrj.io.ops.ConcatOperation;
|
||||||
|
@ -352,6 +275,12 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
||||||
.withFunctionName("ks", KolmogorovSmirnovEvaluator.class)
|
.withFunctionName("ks", KolmogorovSmirnovEvaluator.class)
|
||||||
.withFunctionName("asc", AscEvaluator.class)
|
.withFunctionName("asc", AscEvaluator.class)
|
||||||
.withFunctionName("cumulativeProbability", CumulativeProbabilityEvaluator.class)
|
.withFunctionName("cumulativeProbability", CumulativeProbabilityEvaluator.class)
|
||||||
|
.withFunctionName("ebeAdd", EBEAddEvaluator.class)
|
||||||
|
.withFunctionName("ebeSubtract", EBESubtractEvaluator.class)
|
||||||
|
.withFunctionName("ebeMultiply", EBEMultiplyEvaluator.class)
|
||||||
|
.withFunctionName("ebeDivide", EBEDivideEvaluator.class)
|
||||||
|
.withFunctionName("dotProduct", DotProductEvaluator.class)
|
||||||
|
.withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
|
||||||
|
|
||||||
// Boolean Stream Evaluators
|
// Boolean Stream Evaluators
|
||||||
.withFunctionName("and", AndEvaluator.class)
|
.withFunctionName("and", AndEvaluator.class)
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public CosineSimilarityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(!(first instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
if(!(second instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
double[] d1 = ((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray();
|
||||||
|
double[] d2 = ((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray();
|
||||||
|
|
||||||
|
return cosineSimilarity(d1, d2);
|
||||||
|
}
|
||||||
|
|
||||||
|
private double cosineSimilarity(double[] vectorA, double[] vectorB) {
|
||||||
|
double dotProduct = 0.0;
|
||||||
|
double normA = 0.0;
|
||||||
|
double normB = 0.0;
|
||||||
|
for (int i = 0; i < vectorA.length; i++) {
|
||||||
|
dotProduct += vectorA[i] * vectorB[i];
|
||||||
|
normA += Math.pow(vectorA[i], 2);
|
||||||
|
normB += Math.pow(vectorB[i], 2);
|
||||||
|
}
|
||||||
|
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.linear.RealVector;
|
||||||
|
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||||
|
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class DotProductEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public DotProductEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(!(first instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
if(!(second instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
RealVector v = new ArrayRealVector(((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray());
|
||||||
|
RealVector v2 = new ArrayRealVector(((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray());
|
||||||
|
|
||||||
|
return v.dotProduct(v2);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.MathArrays;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(!(first instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
if(!(second instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
double[] result = MathArrays.ebeAdd(
|
||||||
|
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||||
|
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||||
|
);
|
||||||
|
|
||||||
|
List<Number> numbers = new ArrayList();
|
||||||
|
for(double d : result) {
|
||||||
|
numbers.add(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
return numbers;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.MathArrays;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class EBEDivideEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public EBEDivideEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(!(first instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
if(!(second instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
double[] result = MathArrays.ebeDivide(
|
||||||
|
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||||
|
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||||
|
);
|
||||||
|
|
||||||
|
List<Number> numbers = new ArrayList();
|
||||||
|
for(double d : result) {
|
||||||
|
numbers.add(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
return numbers;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.MathArrays;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class EBEMultiplyEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public EBEMultiplyEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(!(first instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
if(!(second instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
double[] result = MathArrays.ebeMultiply(
|
||||||
|
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||||
|
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||||
|
);
|
||||||
|
|
||||||
|
List<Number> numbers = new ArrayList();
|
||||||
|
for(double d : result) {
|
||||||
|
numbers.add(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
return numbers;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.client.solrj.io.eval;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.math.BigDecimal;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
import org.apache.commons.math3.util.MathArrays;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||||
|
|
||||||
|
public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||||
|
protected static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||||
|
super(expression, factory);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object doWork(Object first, Object second) throws IOException{
|
||||||
|
if(null == first){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(null == second){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||||
|
}
|
||||||
|
if(!(first instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
if(!(second instanceof List<?>)){
|
||||||
|
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
double[] result = MathArrays.ebeSubtract(
|
||||||
|
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||||
|
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||||
|
);
|
||||||
|
|
||||||
|
List<Number> numbers = new ArrayList();
|
||||||
|
for(double d : result) {
|
||||||
|
numbers.add(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
return numbers;
|
||||||
|
}
|
||||||
|
}
|
|
@ -6064,6 +6064,130 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEBESubtract() throws Exception {
|
||||||
|
String cexpr = "ebeSubtract(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||||
|
assertTrue(out.size() == 6);
|
||||||
|
assertTrue(out.get(0).intValue() == 1);
|
||||||
|
assertTrue(out.get(1).intValue() == 2);
|
||||||
|
assertTrue(out.get(2).intValue() == 3);
|
||||||
|
assertTrue(out.get(3).intValue() == 4);
|
||||||
|
assertTrue(out.get(4).intValue() == 5);
|
||||||
|
assertTrue(out.get(5).intValue() == 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEBEMultiply() throws Exception {
|
||||||
|
String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||||
|
assertTrue(out.size() == 6);
|
||||||
|
assertTrue(out.get(0).intValue() == 2);
|
||||||
|
assertTrue(out.get(1).intValue() == 8);
|
||||||
|
assertTrue(out.get(2).intValue() == 18);
|
||||||
|
assertTrue(out.get(3).intValue() == 32);
|
||||||
|
assertTrue(out.get(4).intValue() == 50);
|
||||||
|
assertTrue(out.get(5).intValue() == 72);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEBEAdd() throws Exception {
|
||||||
|
String cexpr = "ebeAdd(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||||
|
assertTrue(out.size() == 6);
|
||||||
|
assertTrue(out.get(0).intValue() == 3);
|
||||||
|
assertTrue(out.get(1).intValue() == 6);
|
||||||
|
assertTrue(out.get(2).intValue() == 9);
|
||||||
|
assertTrue(out.get(3).intValue() == 12);
|
||||||
|
assertTrue(out.get(4).intValue() == 15);
|
||||||
|
assertTrue(out.get(5).intValue() == 18);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEBEDivide() throws Exception {
|
||||||
|
String cexpr = "ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||||
|
assertTrue(out.size() == 6);
|
||||||
|
assertTrue(out.get(0).intValue() == 2);
|
||||||
|
assertTrue(out.get(1).intValue() == 2);
|
||||||
|
assertTrue(out.get(2).intValue() == 2);
|
||||||
|
assertTrue(out.get(3).intValue() == 2);
|
||||||
|
assertTrue(out.get(4).intValue() == 2);
|
||||||
|
assertTrue(out.get(5).intValue() == 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCosineSimilarity() throws Exception {
|
||||||
|
String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
Number cs = (Number)tuples.get(0).get("return-value");
|
||||||
|
assertTrue(cs.doubleValue() == 0.9838197164968291);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDotProduct() throws Exception {
|
||||||
|
String cexpr = "dotProduct(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||||
|
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||||
|
paramsLoc.set("expr", cexpr);
|
||||||
|
paramsLoc.set("qt", "/stream");
|
||||||
|
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||||
|
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
solrStream.setStreamContext(context);
|
||||||
|
List<Tuple> tuples = getTuples(solrStream);
|
||||||
|
assertTrue(tuples.size() == 1);
|
||||||
|
Number dotProduct = (Number)tuples.get(0).get("return-value");
|
||||||
|
assertTrue(dotProduct.doubleValue()== 182);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue