dupirefr/dupire.francois+pro@gmail.com [BAEL-3606] Matrix Multiplication Libraries Comparison (#8298)
* Added benchmarking on larger matrices * [BAEL-3606] Moved benchmarking to production code * [BAEL-3606] Added minor fix
This commit is contained in:
parent
bd87e517dc
commit
1b81033840
|
@ -0,0 +1,121 @@
|
|||
package com.baeldung.matrices.benchmark;
|
||||
|
||||
import cern.colt.matrix.DoubleFactory2D;
|
||||
import cern.colt.matrix.DoubleMatrix2D;
|
||||
import cern.colt.matrix.linalg.Algebra;
|
||||
import com.baeldung.matrices.HomemadeMatrix;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
import org.la4j.Matrix;
|
||||
import org.la4j.matrix.dense.Basic2DMatrix;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
import org.openjdk.jmh.annotations.Mode;
|
||||
import org.openjdk.jmh.runner.Runner;
|
||||
import org.openjdk.jmh.runner.options.ChainedOptionsBuilder;
|
||||
import org.openjdk.jmh.runner.options.OptionsBuilder;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class BigMatrixMultiplicationBenchmarking {
|
||||
private static final int DEFAULT_FORKS = 2;
|
||||
private static final int DEFAULT_WARMUP_ITERATIONS = 5;
|
||||
private static final int DEFAULT_MEASUREMENT_ITERATIONS = 10;
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
Map<String, String> parameters = parseParameters(args);
|
||||
|
||||
ChainedOptionsBuilder builder = new OptionsBuilder()
|
||||
.include(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
|
||||
.mode(Mode.AverageTime)
|
||||
.forks(forks(parameters))
|
||||
.warmupIterations(warmupIterations(parameters))
|
||||
.measurementIterations(measurementIterations(parameters))
|
||||
.timeUnit(TimeUnit.SECONDS);
|
||||
|
||||
parameters.forEach(builder::param);
|
||||
|
||||
new Runner(builder.build()).run();
|
||||
}
|
||||
|
||||
private static Map<String, String> parseParameters(String[] args) {
|
||||
return Arrays.stream(args)
|
||||
.map(arg -> arg.split("="))
|
||||
.collect(Collectors.toMap(
|
||||
arg -> arg[0],
|
||||
arg -> arg[1]
|
||||
));
|
||||
}
|
||||
|
||||
private static int forks(Map<String, String> parameters) {
|
||||
String forks = parameters.remove("forks");
|
||||
return parseOrDefault(forks, DEFAULT_FORKS);
|
||||
}
|
||||
|
||||
private static int warmupIterations(Map<String, String> parameters) {
|
||||
String warmups = parameters.remove("warmupIterations");
|
||||
return parseOrDefault(warmups, DEFAULT_WARMUP_ITERATIONS);
|
||||
}
|
||||
|
||||
private static int measurementIterations(Map<String, String> parameters) {
|
||||
String measurements = parameters.remove("measurementIterations");
|
||||
return parseOrDefault(measurements, DEFAULT_MEASUREMENT_ITERATIONS);
|
||||
}
|
||||
|
||||
private static int parseOrDefault(String parameter, int defaultValue) {
|
||||
return parameter != null ? Integer.parseInt(parameter) : defaultValue;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) {
|
||||
return HomemadeMatrix.multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix());
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) {
|
||||
SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix());
|
||||
SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix());
|
||||
|
||||
return firstMatrix.mult(secondMatrix);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) {
|
||||
RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix());
|
||||
RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix());
|
||||
|
||||
return firstMatrix.multiply(secondMatrix);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
|
||||
Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix());
|
||||
Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix());
|
||||
|
||||
return firstMatrix.multiply(secondMatrix);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
|
||||
INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix());
|
||||
INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix());
|
||||
|
||||
return firstMatrix.mmul(secondMatrix);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) {
|
||||
DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;
|
||||
|
||||
DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix());
|
||||
DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix());
|
||||
|
||||
Algebra algebra = new Algebra();
|
||||
return algebra.mult(firstMatrix, secondMatrix);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package com.baeldung.matrices.benchmark;
|
||||
|
||||
import org.openjdk.jmh.annotations.Param;
|
||||
import org.openjdk.jmh.annotations.Scope;
|
||||
import org.openjdk.jmh.annotations.Setup;
|
||||
import org.openjdk.jmh.annotations.State;
|
||||
import org.openjdk.jmh.infra.BenchmarkParams;
|
||||
|
||||
import java.util.Random;
|
||||
import java.util.stream.DoubleStream;
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
public class BigMatrixProvider {
|
||||
@Param({})
|
||||
private int matrixSize;
|
||||
private double[][] firstMatrix;
|
||||
private double[][] secondMatrix;
|
||||
|
||||
public BigMatrixProvider() {}
|
||||
|
||||
@Setup
|
||||
public void setup(BenchmarkParams parameters) {
|
||||
firstMatrix = createMatrix(matrixSize);
|
||||
secondMatrix = createMatrix(matrixSize);
|
||||
}
|
||||
|
||||
private double[][] createMatrix(int matrixSize) {
|
||||
Random random = new Random();
|
||||
|
||||
double[][] result = new double[matrixSize][matrixSize];
|
||||
for (int row = 0; row < result.length; row++) {
|
||||
for (int col = 0; col < result[row].length; col++) {
|
||||
result[row][col] = random.nextDouble();
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public double[][] getFirstMatrix() {
|
||||
return firstMatrix;
|
||||
}
|
||||
|
||||
public double[][] getSecondMatrix() {
|
||||
return secondMatrix;
|
||||
}
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
package com.baeldung.matrices;
|
||||
package com.baeldung.matrices.benchmark;
|
||||
|
||||
import cern.colt.matrix.DoubleFactory2D;
|
||||
import cern.colt.matrix.DoubleMatrix2D;
|
||||
import cern.colt.matrix.linalg.Algebra;
|
||||
import com.baeldung.matrices.HomemadeMatrix;
|
||||
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.ejml.simple.SimpleMatrix;
|
||||
|
@ -23,9 +24,10 @@ public class MatrixMultiplicationBenchmarking {
|
|||
public static void main(String[] args) throws Exception {
|
||||
Options opt = new OptionsBuilder()
|
||||
.include(MatrixMultiplicationBenchmarking.class.getSimpleName())
|
||||
.exclude(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
|
||||
.mode(Mode.AverageTime)
|
||||
.forks(2)
|
||||
.warmupIterations(5)
|
||||
.warmupIterations(10)
|
||||
.measurementIterations(10)
|
||||
.timeUnit(TimeUnit.MICROSECONDS)
|
||||
.build();
|
|
@ -1,4 +1,4 @@
|
|||
package com.baeldung.matrices;
|
||||
package com.baeldung.matrices.benchmark;
|
||||
|
||||
import org.openjdk.jmh.annotations.Scope;
|
||||
import org.openjdk.jmh.annotations.State;
|
Loading…
Reference in New Issue