SOLR-12054: ebeAdd and ebeSubtract should support matrix operations

This commit is contained in:
Joel Bernstein 2018-03-04 20:22:33 -05:00
parent 97299ed006
commit dc5db9b2f1
3 changed files with 125 additions and 55 deletions

View File

@ -22,10 +22,12 @@ import java.util.List;
import java.util.Locale; import java.util.Locale;
import org.apache.commons.math3.util.MathArrays; import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker { public class EBEAddEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L; protected static final long serialVersionUID = 1L;
public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@ -40,23 +42,28 @@ public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoVal
if(null == second){ if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
} }
if(!(first instanceof List<?>)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
if(!(second instanceof List<?>)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
}
double[] result = MathArrays.ebeAdd( if(first instanceof List && second instanceof List) {
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(), double[] result = MathArrays.ebeAdd(
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray() ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
); ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
);
List<Number> numbers = new ArrayList(); List<Number> numbers = new ArrayList();
for(double d : result) { for (double d : result) {
numbers.add(d); numbers.add(d);
}
return numbers;
} else if(first instanceof Matrix && second instanceof Matrix) {
double[][] data1 = ((Matrix) first).getData();
double[][] data2 = ((Matrix) second).getData();
Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
RealMatrix matrix3 = matrix1.add(matrix2);
return new Matrix(matrix3.getData());
} else {
throw new IOException("Parameters for ebeAdd must either be two numeric arrays or two matrices. ");
} }
return numbers;
} }
} }

View File

@ -21,11 +21,13 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.MathArrays; import org.apache.commons.math3.util.MathArrays;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker { public class EBESubtractEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L; protected static final long serialVersionUID = 1L;
public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{ public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@ -40,23 +42,27 @@ public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements T
if(null == second){ if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory))); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
} }
if(!(first instanceof List<?>)){ if(first instanceof List && second instanceof List) {
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName())); double[] result = MathArrays.ebeSubtract(
} ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
if(!(second instanceof List<?>)){ ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName())); );
}
double[] result = MathArrays.ebeSubtract( List<Number> numbers = new ArrayList();
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(), for (double d : result) {
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray() numbers.add(d);
); }
List<Number> numbers = new ArrayList(); return numbers;
for(double d : result) { } else if(first instanceof Matrix && second instanceof Matrix) {
numbers.add(d); double[][] data1 = ((Matrix) first).getData();
double[][] data2 = ((Matrix) second).getData();
Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
RealMatrix matrix3 = matrix1.subtract(matrix2);
return new Matrix(matrix3.getData());
} else {
throw new IOException("Parameters for ebeSubtract must either be two numeric arrays or two matrices. ");
} }
return numbers;
} }
} }

View File

@ -6975,9 +6975,19 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertEquals(termVectors.get(0).size(), 0); assertEquals(termVectors.get(0).size(), 0);
} }
@Test @Test
public void testEBESubtract() throws Exception { public void testEbeSubtract() throws Exception {
String cexpr = "ebeSubtract(array(2,4,6,8,10,12),array(1,2,3,4,5,6))"; String cexpr = "let(echo=true," +
" a=array(2, 4, 6, 8, 10, 12)," +
" b=array(1, 2, 3, 4, 5, 6)," +
" c=ebeSubtract(a,b)," +
" d=array(10, 11, 12, 13, 14, 15)," +
" e=array(100, 200, 300, 400, 500, 600)," +
" f=matrix(a, b)," +
" g=matrix(d, e)," +
" h=ebeSubtract(f, g))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr); paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream"); paramsLoc.set("qt", "/stream");
@ -6987,16 +6997,35 @@ public class StreamExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context); solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream); List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1); assertTrue(tuples.size() == 1);
List<Number> out = (List<Number>)tuples.get(0).get("return-value"); List<Number> out = (List<Number>)tuples.get(0).get("c");
assertTrue(out.size() == 6); assertEquals(out.size(), 6);
assertTrue(out.get(0).intValue() == 1); assertEquals(out.get(0).doubleValue(), 1.0, 0.0);
assertTrue(out.get(1).intValue() == 2); assertEquals(out.get(1).doubleValue(), 2.0, 0.0);
assertTrue(out.get(2).intValue() == 3); assertEquals(out.get(2).doubleValue(), 3.0, 0.0);
assertTrue(out.get(3).intValue() == 4); assertEquals(out.get(3).doubleValue(), 4.0, 0.0);
assertTrue(out.get(4).intValue() == 5); assertEquals(out.get(4).doubleValue(), 5.0, 0.0);
assertTrue(out.get(5).intValue() == 6); assertEquals(out.get(5).doubleValue(), 6.0, 0.0);
}
List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
assertEquals(mout.size(), 2);
List<Number> row1 = mout.get(0);
assertEquals(row1.size(), 6);
assertEquals(row1.get(0).doubleValue(), -8.0, 0.0);
assertEquals(row1.get(1).doubleValue(), -7.0, 0.0);
assertEquals(row1.get(2).doubleValue(), -6.0, 0.0);
assertEquals(row1.get(3).doubleValue(), -5.0, 0.0);
assertEquals(row1.get(4).doubleValue(), -4.0, 0.0);
assertEquals(row1.get(5).doubleValue(), -3.0, 0.0);
List<Number> row2 = mout.get(1);
assertEquals(row2.size(), 6);
assertEquals(row2.get(0).doubleValue(), -99.0, 0.0);
assertEquals(row2.get(1).doubleValue(), -198.0, 0.0);
assertEquals(row2.get(2).doubleValue(), -297.0, 0.0);
assertEquals(row2.get(3).doubleValue(), -396.0, 0.0);
assertEquals(row2.get(4).doubleValue(), -495.0, 0.0);
assertEquals(row2.get(5).doubleValue(), -594.0, 0.0);
}
@Test @Test
@ -7341,7 +7370,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
} }
@Test @Test
public void testEBEMultiply() throws Exception { public void testEbeMultiply() throws Exception {
String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))"; String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr); paramsLoc.set("expr", cexpr);
@ -7364,8 +7393,16 @@ public class StreamExpressionTest extends SolrCloudTestCase {
@Test @Test
public void testEBEAdd() throws Exception { public void testEbeAdd() throws Exception {
String cexpr = "ebeAdd(array(2,4,6,8,10,12),array(1,2,3,4,5,6))"; String cexpr = "let(echo=true," +
" a=array(2, 4, 6, 8, 10, 12)," +
" b=array(1, 2, 3, 4, 5, 6)," +
" c=ebeAdd(a,b)," +
" d=array(10, 11, 12, 13, 14, 15)," +
" e=array(100, 200, 300, 400, 500, 600)," +
" f=matrix(a, b)," +
" g=matrix(d, e)," +
" h=ebeAdd(f, g))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr); paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream"); paramsLoc.set("qt", "/stream");
@ -7375,19 +7412,39 @@ public class StreamExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context); solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream); List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1); assertTrue(tuples.size() == 1);
List<Number> out = (List<Number>)tuples.get(0).get("return-value"); List<Number> out = (List<Number>)tuples.get(0).get("c");
assertTrue(out.size() == 6); assertEquals(out.size(), 6);
assertTrue(out.get(0).intValue() == 3); assertEquals(out.get(0).doubleValue(), 3.0, 0.0);
assertTrue(out.get(1).intValue() == 6); assertEquals(out.get(1).doubleValue(), 6.0, 0.0);
assertTrue(out.get(2).intValue() == 9); assertEquals(out.get(2).doubleValue(), 9.0, 0.0);
assertTrue(out.get(3).intValue() == 12); assertEquals(out.get(3).doubleValue(), 12.0, 0.0);
assertTrue(out.get(4).intValue() == 15); assertEquals(out.get(4).doubleValue(), 15.0, 0.0);
assertTrue(out.get(5).intValue() == 18); assertEquals(out.get(5).doubleValue(), 18.0, 0.0);
List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
assertEquals(mout.size(), 2);
List<Number> row1 = mout.get(0);
assertEquals(row1.size(), 6);
assertEquals(row1.get(0).doubleValue(), 12.0, 0.0);
assertEquals(row1.get(1).doubleValue(), 15.0, 0.0);
assertEquals(row1.get(2).doubleValue(), 18.0, 0.0);
assertEquals(row1.get(3).doubleValue(), 21.0, 0.0);
assertEquals(row1.get(4).doubleValue(), 24.0, 0.0);
assertEquals(row1.get(5).doubleValue(), 27.0, 0.0);
List<Number> row2 = mout.get(1);
assertEquals(row2.size(), 6);
assertEquals(row2.get(0).doubleValue(), 101.0, 0.0);
assertEquals(row2.get(1).doubleValue(), 202.0, 0.0);
assertEquals(row2.get(2).doubleValue(), 303.0, 0.0);
assertEquals(row2.get(3).doubleValue(), 404.0, 0.0);
assertEquals(row2.get(4).doubleValue(), 505.0, 0.0);
assertEquals(row2.get(5).doubleValue(), 606.0, 0.0);
} }
@Test @Test
public void testEBEDivide() throws Exception { public void testEbeDivide() throws Exception {
String cexpr = "ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))"; String cexpr = "ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams(); ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr); paramsLoc.set("expr", cexpr);