mirror of https://github.com/apache/lucene.git
SOLR-11321: Add ebeAdd, ebeSubtract, ebeDivide, ebeMultiply, dotProduct and cosineSimilarity Stream Evaluators
This commit is contained in:
parent
e782082e71
commit
3423ae4b92
|
@ -33,84 +33,7 @@ import org.apache.solr.client.solrj.io.ModelCache;
|
|||
import org.apache.solr.client.solrj.io.SolrClientCache;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
||||
import org.apache.solr.client.solrj.io.eval.AbsoluteValueEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.AddEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.AndEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.AnovaEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.AppendEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ArcCosineEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ArcSineEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ArcTangentEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ArrayEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.AscEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CeilingEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CoalesceEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ColumnEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ConversionEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ConvolutionEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CopyOfEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CopyOfRangeEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CorrelationEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CosineEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CovarianceEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CubedRootEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.CumulativeProbabilityEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.DescribeEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.DivideEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.EmpiricalDistributionEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.EqualToEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.EuclideanDistanceEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ExclusiveOrEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.FindDelayEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.FloorEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.GreaterThanEqualToEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.GreaterThanEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.HistogramEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.HyperbolicCosineEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.HyperbolicSineEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.HyperbolicTangentEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.IfThenElseEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.KolmogorovSmirnovEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.LengthEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.LessThanEqualToEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.LessThanEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ModuloEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.MovingAverageEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.MultiplyEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.NaturalLogEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.NormalDistributionEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.NormalizeEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.NotEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.OrEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.PercentileEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.PowerEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.PredictEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.RankEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.RawValueEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.RegressionEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ResidualsEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ReverseEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.RoundEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.SampleEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.ScaleEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.SequenceEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.SineEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.SquareRootEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.SubtractEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.TangentEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDay;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDayOfQuarter;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDayOfYear;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorEpoch;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorHour;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorMinute;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorMonth;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorQuarter;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorSecond;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorWeek;
|
||||
import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorYear;
|
||||
import org.apache.solr.client.solrj.io.eval.UniformDistributionEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.UuidEvaluator;
|
||||
import org.apache.solr.client.solrj.io.eval.*;
|
||||
import org.apache.solr.client.solrj.io.graph.GatherNodesStream;
|
||||
import org.apache.solr.client.solrj.io.graph.ShortestPathStream;
|
||||
import org.apache.solr.client.solrj.io.ops.ConcatOperation;
|
||||
|
@ -352,6 +275,12 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
|||
.withFunctionName("ks", KolmogorovSmirnovEvaluator.class)
|
||||
.withFunctionName("asc", AscEvaluator.class)
|
||||
.withFunctionName("cumulativeProbability", CumulativeProbabilityEvaluator.class)
|
||||
.withFunctionName("ebeAdd", EBEAddEvaluator.class)
|
||||
.withFunctionName("ebeSubtract", EBESubtractEvaluator.class)
|
||||
.withFunctionName("ebeMultiply", EBEMultiplyEvaluator.class)
|
||||
.withFunctionName("ebeDivide", EBEDivideEvaluator.class)
|
||||
.withFunctionName("dotProduct", DotProductEvaluator.class)
|
||||
.withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
|
||||
|
||||
// Boolean Stream Evaluators
|
||||
.withFunctionName("and", AndEvaluator.class)
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public CosineSimilarityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
double[] d1 = ((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray();
|
||||
double[] d2 = ((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray();
|
||||
|
||||
return cosineSimilarity(d1, d2);
|
||||
}
|
||||
|
||||
private double cosineSimilarity(double[] vectorA, double[] vectorB) {
|
||||
double dotProduct = 0.0;
|
||||
double normA = 0.0;
|
||||
double normB = 0.0;
|
||||
for (int i = 0; i < vectorA.length; i++) {
|
||||
dotProduct += vectorA[i] * vectorB[i];
|
||||
normA += Math.pow(vectorA[i], 2);
|
||||
normB += Math.pow(vectorB[i], 2);
|
||||
}
|
||||
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.linear.RealVector;
|
||||
import org.apache.commons.math3.linear.ArrayRealVector;
|
||||
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class DotProductEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public DotProductEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
RealVector v = new ArrayRealVector(((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray());
|
||||
RealVector v2 = new ArrayRealVector(((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray());
|
||||
|
||||
return v.dotProduct(v2);
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
double[] result = MathArrays.ebeAdd(
|
||||
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||
);
|
||||
|
||||
List<Number> numbers = new ArrayList();
|
||||
for(double d : result) {
|
||||
numbers.add(d);
|
||||
}
|
||||
|
||||
return numbers;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EBEDivideEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EBEDivideEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
double[] result = MathArrays.ebeDivide(
|
||||
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||
);
|
||||
|
||||
List<Number> numbers = new ArrayList();
|
||||
for(double d : result) {
|
||||
numbers.add(d);
|
||||
}
|
||||
|
||||
return numbers;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EBEMultiplyEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EBEMultiplyEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
double[] result = MathArrays.ebeMultiply(
|
||||
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||
);
|
||||
|
||||
List<Number> numbers = new ArrayList();
|
||||
for(double d : result) {
|
||||
numbers.add(d);
|
||||
}
|
||||
|
||||
return numbers;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.client.solrj.io.eval;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doWork(Object first, Object second) throws IOException{
|
||||
if(null == first){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
double[] result = MathArrays.ebeSubtract(
|
||||
((List) first).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray()
|
||||
);
|
||||
|
||||
List<Number> numbers = new ArrayList();
|
||||
for(double d : result) {
|
||||
numbers.add(d);
|
||||
}
|
||||
|
||||
return numbers;
|
||||
}
|
||||
}
|
|
@ -6064,6 +6064,130 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEBESubtract() throws Exception {
|
||||
String cexpr = "ebeSubtract(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
assertTrue(out.get(0).intValue() == 1);
|
||||
assertTrue(out.get(1).intValue() == 2);
|
||||
assertTrue(out.get(2).intValue() == 3);
|
||||
assertTrue(out.get(3).intValue() == 4);
|
||||
assertTrue(out.get(4).intValue() == 5);
|
||||
assertTrue(out.get(5).intValue() == 6);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEBEMultiply() throws Exception {
|
||||
String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
assertTrue(out.get(0).intValue() == 2);
|
||||
assertTrue(out.get(1).intValue() == 8);
|
||||
assertTrue(out.get(2).intValue() == 18);
|
||||
assertTrue(out.get(3).intValue() == 32);
|
||||
assertTrue(out.get(4).intValue() == 50);
|
||||
assertTrue(out.get(5).intValue() == 72);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEBEAdd() throws Exception {
|
||||
String cexpr = "ebeAdd(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
assertTrue(out.get(0).intValue() == 3);
|
||||
assertTrue(out.get(1).intValue() == 6);
|
||||
assertTrue(out.get(2).intValue() == 9);
|
||||
assertTrue(out.get(3).intValue() == 12);
|
||||
assertTrue(out.get(4).intValue() == 15);
|
||||
assertTrue(out.get(5).intValue() == 18);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEBEDivide() throws Exception {
|
||||
String cexpr = "ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
assertTrue(out.get(0).intValue() == 2);
|
||||
assertTrue(out.get(1).intValue() == 2);
|
||||
assertTrue(out.get(2).intValue() == 2);
|
||||
assertTrue(out.get(3).intValue() == 2);
|
||||
assertTrue(out.get(4).intValue() == 2);
|
||||
assertTrue(out.get(5).intValue() == 2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineSimilarity() throws Exception {
|
||||
String cexpr = "cosineSimilarity(array(2,4,6,8),array(1,1,3,4))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Number cs = (Number)tuples.get(0).get("return-value");
|
||||
assertTrue(cs.doubleValue() == 0.9838197164968291);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testDotProduct() throws Exception {
|
||||
String cexpr = "dotProduct(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
TupleStream solrStream = new SolrStream(url, paramsLoc);
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
Number dotProduct = (Number)tuples.get(0).get("return-value");
|
||||
assertTrue(dotProduct.doubleValue()== 182);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue