From 3423ae4b920ef51d685c9efc7f2cad390d987b87 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Mon, 4 Sep 2017 18:40:03 -0400 Subject: [PATCH] SOLR-11321: Add ebeAdd, ebeSubtract, ebeDivide, ebeMultiply, dotProduct and cosineSimilarity Stream Evaluators --- .../apache/solr/handler/StreamHandler.java | 85 +----------- .../io/eval/CosineSimilarityEvaluator.java | 67 ++++++++++ .../solrj/io/eval/DotProductEvaluator.java | 58 ++++++++ .../client/solrj/io/eval/EBEAddEvaluator.java | 63 +++++++++ .../solrj/io/eval/EBEDivideEvaluator.java | 63 +++++++++ .../solrj/io/eval/EBEMultiplyEvaluator.java | 63 +++++++++ .../solrj/io/eval/EBESubtractEvaluator.java | 63 +++++++++ .../solrj/io/stream/StreamExpressionTest.java | 124 ++++++++++++++++++ 8 files changed, 508 insertions(+), 78 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DotProductEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEDivideEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEMultiplyEvaluator.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index 8b6a6c0225d..9613ec33a06 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -33,84 +33,7 @@ import org.apache.solr.client.solrj.io.ModelCache; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.StreamComparator; -import 
org.apache.solr.client.solrj.io.eval.AbsoluteValueEvaluator; -import org.apache.solr.client.solrj.io.eval.AddEvaluator; -import org.apache.solr.client.solrj.io.eval.AndEvaluator; -import org.apache.solr.client.solrj.io.eval.AnovaEvaluator; -import org.apache.solr.client.solrj.io.eval.AppendEvaluator; -import org.apache.solr.client.solrj.io.eval.ArcCosineEvaluator; -import org.apache.solr.client.solrj.io.eval.ArcSineEvaluator; -import org.apache.solr.client.solrj.io.eval.ArcTangentEvaluator; -import org.apache.solr.client.solrj.io.eval.ArrayEvaluator; -import org.apache.solr.client.solrj.io.eval.AscEvaluator; -import org.apache.solr.client.solrj.io.eval.CeilingEvaluator; -import org.apache.solr.client.solrj.io.eval.CoalesceEvaluator; -import org.apache.solr.client.solrj.io.eval.ColumnEvaluator; -import org.apache.solr.client.solrj.io.eval.ConversionEvaluator; -import org.apache.solr.client.solrj.io.eval.ConvolutionEvaluator; -import org.apache.solr.client.solrj.io.eval.CopyOfEvaluator; -import org.apache.solr.client.solrj.io.eval.CopyOfRangeEvaluator; -import org.apache.solr.client.solrj.io.eval.CorrelationEvaluator; -import org.apache.solr.client.solrj.io.eval.CosineEvaluator; -import org.apache.solr.client.solrj.io.eval.CovarianceEvaluator; -import org.apache.solr.client.solrj.io.eval.CubedRootEvaluator; -import org.apache.solr.client.solrj.io.eval.CumulativeProbabilityEvaluator; -import org.apache.solr.client.solrj.io.eval.DescribeEvaluator; -import org.apache.solr.client.solrj.io.eval.DivideEvaluator; -import org.apache.solr.client.solrj.io.eval.EmpiricalDistributionEvaluator; -import org.apache.solr.client.solrj.io.eval.EqualToEvaluator; -import org.apache.solr.client.solrj.io.eval.EuclideanDistanceEvaluator; -import org.apache.solr.client.solrj.io.eval.ExclusiveOrEvaluator; -import org.apache.solr.client.solrj.io.eval.FindDelayEvaluator; -import org.apache.solr.client.solrj.io.eval.FloorEvaluator; -import 
org.apache.solr.client.solrj.io.eval.GreaterThanEqualToEvaluator; -import org.apache.solr.client.solrj.io.eval.GreaterThanEvaluator; -import org.apache.solr.client.solrj.io.eval.HistogramEvaluator; -import org.apache.solr.client.solrj.io.eval.HyperbolicCosineEvaluator; -import org.apache.solr.client.solrj.io.eval.HyperbolicSineEvaluator; -import org.apache.solr.client.solrj.io.eval.HyperbolicTangentEvaluator; -import org.apache.solr.client.solrj.io.eval.IfThenElseEvaluator; -import org.apache.solr.client.solrj.io.eval.KolmogorovSmirnovEvaluator; -import org.apache.solr.client.solrj.io.eval.LengthEvaluator; -import org.apache.solr.client.solrj.io.eval.LessThanEqualToEvaluator; -import org.apache.solr.client.solrj.io.eval.LessThanEvaluator; -import org.apache.solr.client.solrj.io.eval.ModuloEvaluator; -import org.apache.solr.client.solrj.io.eval.MovingAverageEvaluator; -import org.apache.solr.client.solrj.io.eval.MultiplyEvaluator; -import org.apache.solr.client.solrj.io.eval.NaturalLogEvaluator; -import org.apache.solr.client.solrj.io.eval.NormalDistributionEvaluator; -import org.apache.solr.client.solrj.io.eval.NormalizeEvaluator; -import org.apache.solr.client.solrj.io.eval.NotEvaluator; -import org.apache.solr.client.solrj.io.eval.OrEvaluator; -import org.apache.solr.client.solrj.io.eval.PercentileEvaluator; -import org.apache.solr.client.solrj.io.eval.PowerEvaluator; -import org.apache.solr.client.solrj.io.eval.PredictEvaluator; -import org.apache.solr.client.solrj.io.eval.RankEvaluator; -import org.apache.solr.client.solrj.io.eval.RawValueEvaluator; -import org.apache.solr.client.solrj.io.eval.RegressionEvaluator; -import org.apache.solr.client.solrj.io.eval.ResidualsEvaluator; -import org.apache.solr.client.solrj.io.eval.ReverseEvaluator; -import org.apache.solr.client.solrj.io.eval.RoundEvaluator; -import org.apache.solr.client.solrj.io.eval.SampleEvaluator; -import org.apache.solr.client.solrj.io.eval.ScaleEvaluator; -import 
org.apache.solr.client.solrj.io.eval.SequenceEvaluator; -import org.apache.solr.client.solrj.io.eval.SineEvaluator; -import org.apache.solr.client.solrj.io.eval.SquareRootEvaluator; -import org.apache.solr.client.solrj.io.eval.SubtractEvaluator; -import org.apache.solr.client.solrj.io.eval.TangentEvaluator; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDay; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDayOfQuarter; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorDayOfYear; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorEpoch; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorHour; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorMinute; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorMonth; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorQuarter; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorSecond; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorWeek; -import org.apache.solr.client.solrj.io.eval.TemporalEvaluatorYear; -import org.apache.solr.client.solrj.io.eval.UniformDistributionEvaluator; -import org.apache.solr.client.solrj.io.eval.UuidEvaluator; +import org.apache.solr.client.solrj.io.eval.*; import org.apache.solr.client.solrj.io.graph.GatherNodesStream; import org.apache.solr.client.solrj.io.graph.ShortestPathStream; import org.apache.solr.client.solrj.io.ops.ConcatOperation; @@ -352,6 +275,12 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("ks", KolmogorovSmirnovEvaluator.class) .withFunctionName("asc", AscEvaluator.class) .withFunctionName("cumulativeProbability", CumulativeProbabilityEvaluator.class) + .withFunctionName("ebeAdd", EBEAddEvaluator.class) + .withFunctionName("ebeSubtract", EBESubtractEvaluator.class) + .withFunctionName("ebeMultiply", EBEMultiplyEvaluator.class) + .withFunctionName("ebeDivide", EBEDivideEvaluator.class) + 
.withFunctionName("dotProduct", DotProductEvaluator.class)
+      .withFunctionName("cosineSimilarity", CosineSimilarityEvaluator.class)
 
       // Boolean Stream Evaluators
       .withFunctionName("and", AndEvaluator.class)
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java
new file mode 100644
index 00000000000..ea88400b7a6
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CosineSimilarityEvaluator.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that computes the cosine similarity of two lists of numbers.
+ */
+public class CosineSimilarityEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public CosineSimilarityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Computes the cosine similarity of the two supplied vectors.
+   *
+   * @param first the first vector, a list of numbers
+   * @param second the second vector, a list of numbers with the same length as {@code first}
+   * @return the cosine similarity as a double; NaN when either vector has a zero norm
+   * @throws IOException if either value is null or not a list, or if the lists differ in length
+   */
+  @Override
+  public Object doWork(Object first, Object second) throws IOException{
+    if(null == first){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
+    }
+    if(null == second){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
+    }
+    if(!(first instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if(!(second instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+
+    double[] d1 = toDoubleArray((List<?>) first);
+    double[] d2 = toDoubleArray((List<?>) second);
+    if(d1.length != d2.length){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found lists of different lengths (%d and %d), expecting lists of the same length",toExpression(constructingFactory), d1.length, d2.length));
+    }
+
+    return cosineSimilarity(d1, d2);
+  }
+
+  // Casts to Number rather than BigDecimal so that any numeric element type is accepted.
+  private static double[] toDoubleArray(List<?> values) {
+    return values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
+  }
+
+  // Both arrays must have the same length; doWork checks this before calling.
+  private double cosineSimilarity(double[] vectorA, double[] vectorB) {
+    double dotProduct = 0.0;
+    double normA = 0.0;
+    double normB = 0.0;
+    for (int i = 0; i < vectorA.length; i++) {
+      dotProduct += vectorA[i] * vectorB[i];
+      normA += Math.pow(vectorA[i], 2);
+      normB += Math.pow(vectorB[i], 2);
+    }
+    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+  }
+
+}
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DotProductEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DotProductEvaluator.java
new file mode 100644
index 00000000000..3133bac7d1c
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/DotProductEvaluator.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.commons.math3.linear.ArrayRealVector;
+
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that computes the dot product of two lists of numbers.
+ */
+public class DotProductEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public DotProductEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Computes the dot product of the two supplied vectors.
+   *
+   * @param first the first vector, a list of numbers
+   * @param second the second vector, a list of numbers
+   * @return the dot product as a double
+   * @throws IOException if either value is null or is not a list
+   */
+  @Override
+  public Object doWork(Object first, Object second) throws IOException{
+    if(null == first){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
+    }
+    if(null == second){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
+    }
+    if(!(first instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if(!(second instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+
+    RealVector v = new ArrayRealVector(toDoubleArray((List<?>) first));
+    RealVector v2 = new ArrayRealVector(toDoubleArray((List<?>) second));
+
+    return v.dotProduct(v2);
+  }
+
+  // Casts to Number rather than BigDecimal so that any numeric element type is accepted.
+  private static double[] toDoubleArray(List<?> values) {
+    return values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
+  }
+}
diff --git
a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
new file mode 100644
index 00000000000..c1eec9b71a3
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.math3.util.MathArrays;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that adds two lists of numbers element by element.
+ */
+public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Adds the two supplied vectors element by element via {@link MathArrays#ebeAdd(double[], double[])}.
+   *
+   * @param first the first vector, a list of numbers
+   * @param second the second vector, a list of numbers
+   * @return a new list holding {@code first[i] + second[i]} for every index
+   * @throws IOException if either value is null or is not a list
+   */
+  @Override
+  public Object doWork(Object first, Object second) throws IOException{
+    if(null == first){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
+    }
+    if(null == second){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
+    }
+    if(!(first instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if(!(second instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+
+    double[] result = MathArrays.ebeAdd(
+      toDoubleArray((List<?>) first),
+      toDoubleArray((List<?>) second)
+    );
+
+    List<Number> numbers = new ArrayList<>(result.length);
+    for(double d : result) {
+      numbers.add(d);
+    }
+
+    return numbers;
+  }
+
+  // Casts to Number rather than BigDecimal so that any numeric element type is accepted.
+  private static double[] toDoubleArray(List<?> values) {
+    return values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
+  }
+}
diff --git
a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEDivideEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEDivideEvaluator.java
new file mode 100644
index 00000000000..c457f68795a
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEDivideEvaluator.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.math3.util.MathArrays;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that divides two lists of numbers element by element.
+ */
+public class EBEDivideEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public EBEDivideEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Divides the two supplied vectors element by element via {@link MathArrays#ebeDivide(double[], double[])}.
+   *
+   * @param first the dividend vector, a list of numbers
+   * @param second the divisor vector, a list of numbers
+   * @return a new list holding {@code first[i] / second[i]} for every index
+   * @throws IOException if either value is null or is not a list
+   */
+  @Override
+  public Object doWork(Object first, Object second) throws IOException{
+    if(null == first){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
+    }
+    if(null == second){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
+    }
+    if(!(first instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if(!(second instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+
+    double[] result = MathArrays.ebeDivide(
+      toDoubleArray((List<?>) first),
+      toDoubleArray((List<?>) second)
+    );
+
+    List<Number> numbers = new ArrayList<>(result.length);
+    for(double d : result) {
+      numbers.add(d);
+    }
+
+    return numbers;
+  }
+
+  // Casts to Number rather than BigDecimal so that any numeric element type is accepted.
+  private static double[] toDoubleArray(List<?> values) {
+    return values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
+  }
+}
diff --git
a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEMultiplyEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEMultiplyEvaluator.java
new file mode 100644
index 00000000000..b3617cdd37f
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEMultiplyEvaluator.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.math3.util.MathArrays;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that multiplies two lists of numbers element by element.
+ */
+public class EBEMultiplyEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public EBEMultiplyEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Multiplies the two supplied vectors element by element via {@link MathArrays#ebeMultiply(double[], double[])}.
+   *
+   * @param first the first vector, a list of numbers
+   * @param second the second vector, a list of numbers
+   * @return a new list holding {@code first[i] * second[i]} for every index
+   * @throws IOException if either value is null or is not a list
+   */
+  @Override
+  public Object doWork(Object first, Object second) throws IOException{
+    if(null == first){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
+    }
+    if(null == second){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
+    }
+    if(!(first instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if(!(second instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+
+    double[] result = MathArrays.ebeMultiply(
+      toDoubleArray((List<?>) first),
+      toDoubleArray((List<?>) second)
+    );
+
+    List<Number> numbers = new ArrayList<>(result.length);
+    for(double d : result) {
+      numbers.add(d);
+    }
+
+    return numbers;
+  }
+
+  // Casts to Number rather than BigDecimal so that any numeric element type is accepted.
+  private static double[] toDoubleArray(List<?> values) {
+    return values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
+  }
+}
diff --git
a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
new file mode 100644
index 00000000000..2f2f0223bb3
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.math3.util.MathArrays;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that subtracts two lists of numbers element by element.
+ */
+public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Subtracts the two supplied vectors element by element via {@link MathArrays#ebeSubtract(double[], double[])}.
+   *
+   * @param first the vector to subtract from, a list of numbers
+   * @param second the vector to subtract, a list of numbers
+   * @return a new list holding {@code first[i] - second[i]} for every index
+   * @throws IOException if either value is null or is not a list
+   */
+  @Override
+  public Object doWork(Object first, Object second) throws IOException{
+    if(null == first){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
+    }
+    if(null == second){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
+    }
+    if(!(first instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if(!(second instanceof List)){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+
+    double[] result = MathArrays.ebeSubtract(
+      toDoubleArray((List<?>) first),
+      toDoubleArray((List<?>) second)
+    );
+
+    List<Number> numbers = new ArrayList<>(result.length);
+    for(double d : result) {
+      numbers.add(d);
+    }
+
+    return numbers;
+  }
+
+  // Casts to Number rather than BigDecimal so that any numeric element type is accepted.
+  private static double[] toDoubleArray(List<?> values) {
+    return values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
+  }
+}
diff --git
a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index 670b39d4fd0..f831daca740 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -6064,6 +6064,66 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     }
   }
 
+  // Sends a single expression to the /stream handler of the first node and collects the tuples it returns.
+  private List<Tuple> getTuplesForExpression(String cexpr) throws Exception {
+    ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    TupleStream solrStream = new SolrStream(url, paramsLoc);
+    StreamContext context = new StreamContext();
+    solrStream.setStreamContext(context);
+    return getTuples(solrStream);
+  }
+
+  // Asserts that the expression yields exactly one tuple whose "return-value" list has the expected int values.
+  private void assertVectorResult(String cexpr, int... expected) throws Exception {
+    List<Tuple> tuples = getTuplesForExpression(cexpr);
+    assertTrue(tuples.size() == 1);
+    @SuppressWarnings("unchecked")
+    List<Number> out = (List<Number>)tuples.get(0).get("return-value");
+    assertTrue(out.size() == expected.length);
+    for (int i = 0; i < expected.length; i++) {
+      assertTrue(out.get(i).intValue() == expected[i]);
+    }
+  }
+
+  @Test
+  public void testEBESubtract() throws Exception {
+    assertVectorResult("ebeSubtract(array(2,4,6,8,10,12),array(1,2,3,4,5,6))", 1, 2, 3, 4, 5, 6);
+  }
+
+  @Test
+  public void testEBEMultiply() throws Exception {
+    assertVectorResult("ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))", 2, 8, 18, 32, 50, 72);
+  }
+
+  @Test
+  public void testEBEAdd() throws Exception {
+    assertVectorResult("ebeAdd(array(2,4,6,8,10,12),array(1,2,3,4,5,6))", 3, 6, 9, 12, 15, 18);
+  }
+
+  @Test
+  public void testEBEDivide() throws Exception {
+    assertVectorResult("ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))", 2, 2, 2, 2, 2, 2);
+  }
+
+  @Test
+  public void testCosineSimilarity() throws Exception {
+    List<Tuple> tuples = getTuplesForExpression("cosineSimilarity(array(2,4,6,8),array(1,1,3,4))");
+    assertTrue(tuples.size() == 1);
+    Number cs = (Number)tuples.get(0).get("return-value");
+    // Compare with a tolerance rather than exact floating-point equality.
+    assertTrue(Math.abs(cs.doubleValue() - 0.9838197164968291) < 1e-12);
+  }
+
+  @Test
+  public void testDotProduct() throws Exception {
+    List<Tuple> tuples = getTuplesForExpression("dotProduct(array(2,4,6,8,10,12),array(1,2,3,4,5,6))");
+    assertTrue(tuples.size() == 1);
+    Number dotProduct = (Number)tuples.get(0).get("return-value");
+    assertTrue(dotProduct.doubleValue()== 182);
+  }
 
   @Test