mirror of https://github.com/apache/lucene.git
SOLR-12054: ebeAdd and ebeSubtract should support matrix operations
This commit is contained in:
parent
97299ed006
commit
dc5db9b2f1
|
@ -22,10 +22,12 @@ import java.util.List;
|
|||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
public class EBEAddEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
|
@ -40,23 +42,28 @@ public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoVal
|
|||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
|
||||
double[] result = MathArrays.ebeAdd(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
if(first instanceof List && second instanceof List) {
|
||||
double[] result = MathArrays.ebeAdd(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
|
||||
List<Number> numbers = new ArrayList();
|
||||
for(double d : result) {
|
||||
numbers.add(d);
|
||||
List<Number> numbers = new ArrayList();
|
||||
for (double d : result) {
|
||||
numbers.add(d);
|
||||
}
|
||||
|
||||
return numbers;
|
||||
} else if(first instanceof Matrix && second instanceof Matrix) {
|
||||
double[][] data1 = ((Matrix) first).getData();
|
||||
double[][] data2 = ((Matrix) second).getData();
|
||||
Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
|
||||
Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
|
||||
RealMatrix matrix3 = matrix1.add(matrix2);
|
||||
return new Matrix(matrix3.getData());
|
||||
} else {
|
||||
throw new IOException("Parameters for ebeAdd must either be two numeric arrays or two matrices. ");
|
||||
}
|
||||
|
||||
return numbers;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,11 +21,13 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.apache.commons.math3.util.MathArrays;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
|
||||
public class EBESubtractEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
|
||||
public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
|
@ -40,23 +42,27 @@ public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements T
|
|||
if(null == second){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
|
||||
}
|
||||
if(!(first instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(!(second instanceof List<?>)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
|
||||
}
|
||||
if(first instanceof List && second instanceof List) {
|
||||
double[] result = MathArrays.ebeSubtract(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
|
||||
double[] result = MathArrays.ebeSubtract(
|
||||
((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
|
||||
((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
|
||||
);
|
||||
List<Number> numbers = new ArrayList();
|
||||
for (double d : result) {
|
||||
numbers.add(d);
|
||||
}
|
||||
|
||||
List<Number> numbers = new ArrayList();
|
||||
for(double d : result) {
|
||||
numbers.add(d);
|
||||
return numbers;
|
||||
} else if(first instanceof Matrix && second instanceof Matrix) {
|
||||
double[][] data1 = ((Matrix) first).getData();
|
||||
double[][] data2 = ((Matrix) second).getData();
|
||||
Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
|
||||
Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
|
||||
RealMatrix matrix3 = matrix1.subtract(matrix2);
|
||||
return new Matrix(matrix3.getData());
|
||||
} else {
|
||||
throw new IOException("Parameters for ebeSubtract must either be two numeric arrays or two matrices. ");
|
||||
}
|
||||
|
||||
return numbers;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6975,9 +6975,19 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
assertEquals(termVectors.get(0).size(), 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testEBESubtract() throws Exception {
|
||||
String cexpr = "ebeSubtract(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
public void testEbeSubtract() throws Exception {
|
||||
String cexpr = "let(echo=true," +
|
||||
" a=array(2, 4, 6, 8, 10, 12)," +
|
||||
" b=array(1, 2, 3, 4, 5, 6)," +
|
||||
" c=ebeSubtract(a,b)," +
|
||||
" d=array(10, 11, 12, 13, 14, 15)," +
|
||||
" e=array(100, 200, 300, 400, 500, 600)," +
|
||||
" f=matrix(a, b)," +
|
||||
" g=matrix(d, e)," +
|
||||
" h=ebeSubtract(f, g))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
@ -6987,16 +6997,35 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
assertTrue(out.get(0).intValue() == 1);
|
||||
assertTrue(out.get(1).intValue() == 2);
|
||||
assertTrue(out.get(2).intValue() == 3);
|
||||
assertTrue(out.get(3).intValue() == 4);
|
||||
assertTrue(out.get(4).intValue() == 5);
|
||||
assertTrue(out.get(5).intValue() == 6);
|
||||
}
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("c");
|
||||
assertEquals(out.size(), 6);
|
||||
assertEquals(out.get(0).doubleValue(), 1.0, 0.0);
|
||||
assertEquals(out.get(1).doubleValue(), 2.0, 0.0);
|
||||
assertEquals(out.get(2).doubleValue(), 3.0, 0.0);
|
||||
assertEquals(out.get(3).doubleValue(), 4.0, 0.0);
|
||||
assertEquals(out.get(4).doubleValue(), 5.0, 0.0);
|
||||
assertEquals(out.get(5).doubleValue(), 6.0, 0.0);
|
||||
|
||||
List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
|
||||
assertEquals(mout.size(), 2);
|
||||
List<Number> row1 = mout.get(0);
|
||||
assertEquals(row1.size(), 6);
|
||||
assertEquals(row1.get(0).doubleValue(), -8.0, 0.0);
|
||||
assertEquals(row1.get(1).doubleValue(), -7.0, 0.0);
|
||||
assertEquals(row1.get(2).doubleValue(), -6.0, 0.0);
|
||||
assertEquals(row1.get(3).doubleValue(), -5.0, 0.0);
|
||||
assertEquals(row1.get(4).doubleValue(), -4.0, 0.0);
|
||||
assertEquals(row1.get(5).doubleValue(), -3.0, 0.0);
|
||||
|
||||
List<Number> row2 = mout.get(1);
|
||||
assertEquals(row2.size(), 6);
|
||||
assertEquals(row2.get(0).doubleValue(), -99.0, 0.0);
|
||||
assertEquals(row2.get(1).doubleValue(), -198.0, 0.0);
|
||||
assertEquals(row2.get(2).doubleValue(), -297.0, 0.0);
|
||||
assertEquals(row2.get(3).doubleValue(), -396.0, 0.0);
|
||||
assertEquals(row2.get(4).doubleValue(), -495.0, 0.0);
|
||||
assertEquals(row2.get(5).doubleValue(), -594.0, 0.0);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
|
@ -7341,7 +7370,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testEBEMultiply() throws Exception {
|
||||
public void testEbeMultiply() throws Exception {
|
||||
String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
@ -7364,8 +7393,16 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
|
||||
|
||||
@Test
|
||||
public void testEBEAdd() throws Exception {
|
||||
String cexpr = "ebeAdd(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
public void testEbeAdd() throws Exception {
|
||||
String cexpr = "let(echo=true," +
|
||||
" a=array(2, 4, 6, 8, 10, 12)," +
|
||||
" b=array(1, 2, 3, 4, 5, 6)," +
|
||||
" c=ebeAdd(a,b)," +
|
||||
" d=array(10, 11, 12, 13, 14, 15)," +
|
||||
" e=array(100, 200, 300, 400, 500, 600)," +
|
||||
" f=matrix(a, b)," +
|
||||
" g=matrix(d, e)," +
|
||||
" h=ebeAdd(f, g))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
@ -7375,19 +7412,39 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("return-value");
|
||||
assertTrue(out.size() == 6);
|
||||
assertTrue(out.get(0).intValue() == 3);
|
||||
assertTrue(out.get(1).intValue() == 6);
|
||||
assertTrue(out.get(2).intValue() == 9);
|
||||
assertTrue(out.get(3).intValue() == 12);
|
||||
assertTrue(out.get(4).intValue() == 15);
|
||||
assertTrue(out.get(5).intValue() == 18);
|
||||
List<Number> out = (List<Number>)tuples.get(0).get("c");
|
||||
assertEquals(out.size(), 6);
|
||||
assertEquals(out.get(0).doubleValue(), 3.0, 0.0);
|
||||
assertEquals(out.get(1).doubleValue(), 6.0, 0.0);
|
||||
assertEquals(out.get(2).doubleValue(), 9.0, 0.0);
|
||||
assertEquals(out.get(3).doubleValue(), 12.0, 0.0);
|
||||
assertEquals(out.get(4).doubleValue(), 15.0, 0.0);
|
||||
assertEquals(out.get(5).doubleValue(), 18.0, 0.0);
|
||||
|
||||
List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
|
||||
assertEquals(mout.size(), 2);
|
||||
List<Number> row1 = mout.get(0);
|
||||
assertEquals(row1.size(), 6);
|
||||
assertEquals(row1.get(0).doubleValue(), 12.0, 0.0);
|
||||
assertEquals(row1.get(1).doubleValue(), 15.0, 0.0);
|
||||
assertEquals(row1.get(2).doubleValue(), 18.0, 0.0);
|
||||
assertEquals(row1.get(3).doubleValue(), 21.0, 0.0);
|
||||
assertEquals(row1.get(4).doubleValue(), 24.0, 0.0);
|
||||
assertEquals(row1.get(5).doubleValue(), 27.0, 0.0);
|
||||
|
||||
List<Number> row2 = mout.get(1);
|
||||
assertEquals(row2.size(), 6);
|
||||
assertEquals(row2.get(0).doubleValue(), 101.0, 0.0);
|
||||
assertEquals(row2.get(1).doubleValue(), 202.0, 0.0);
|
||||
assertEquals(row2.get(2).doubleValue(), 303.0, 0.0);
|
||||
assertEquals(row2.get(3).doubleValue(), 404.0, 0.0);
|
||||
assertEquals(row2.get(4).doubleValue(), 505.0, 0.0);
|
||||
assertEquals(row2.get(5).doubleValue(), 606.0, 0.0);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testEBEDivide() throws Exception {
|
||||
public void testEbeDivide() throws Exception {
|
||||
String cexpr = "ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", cexpr);
|
||||
|
|
Loading…
Reference in New Issue