dupirefr/dupire.francois+pro@gmail.com [BAEL-3606] Matrix Multiplication Libraries Comparison (#8298)

* Added benchmarking on larger matrices

* [BAEL-3606] Moved benchmarking to production code

* [BAEL-3606] Added minor fix
This commit is contained in:
François Dupire 2019-12-03 06:53:52 +01:00 committed by maibin
parent bd87e517dc
commit 1b81033840
4 changed files with 172 additions and 3 deletions

View File

@ -0,0 +1,121 @@
package com.baeldung.matrices.benchmark;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import com.baeldung.matrices.HomemadeMatrix;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.ejml.simple.SimpleMatrix;
import org.la4j.Matrix;
import org.la4j.matrix.dense.Basic2DMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.ChainedOptionsBuilder;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public class BigMatrixMultiplicationBenchmarking {
private static final int DEFAULT_FORKS = 2;
private static final int DEFAULT_WARMUP_ITERATIONS = 5;
private static final int DEFAULT_MEASUREMENT_ITERATIONS = 10;
public static void main(String[] args) throws Exception {
Map<String, String> parameters = parseParameters(args);
ChainedOptionsBuilder builder = new OptionsBuilder()
.include(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
.mode(Mode.AverageTime)
.forks(forks(parameters))
.warmupIterations(warmupIterations(parameters))
.measurementIterations(measurementIterations(parameters))
.timeUnit(TimeUnit.SECONDS);
parameters.forEach(builder::param);
new Runner(builder.build()).run();
}
private static Map<String, String> parseParameters(String[] args) {
return Arrays.stream(args)
.map(arg -> arg.split("="))
.collect(Collectors.toMap(
arg -> arg[0],
arg -> arg[1]
));
}
private static int forks(Map<String, String> parameters) {
String forks = parameters.remove("forks");
return parseOrDefault(forks, DEFAULT_FORKS);
}
private static int warmupIterations(Map<String, String> parameters) {
String warmups = parameters.remove("warmupIterations");
return parseOrDefault(warmups, DEFAULT_WARMUP_ITERATIONS);
}
private static int measurementIterations(Map<String, String> parameters) {
String measurements = parameters.remove("measurementIterations");
return parseOrDefault(measurements, DEFAULT_MEASUREMENT_ITERATIONS);
}
private static int parseOrDefault(String parameter, int defaultValue) {
return parameter != null ? Integer.parseInt(parameter) : defaultValue;
}
@Benchmark
public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) {
return HomemadeMatrix.multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix());
}
@Benchmark
public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) {
SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix());
SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix());
return firstMatrix.mult(secondMatrix);
}
@Benchmark
public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) {
RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix());
RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix());
return firstMatrix.multiply(secondMatrix);
}
@Benchmark
public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix());
Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix());
return firstMatrix.multiply(secondMatrix);
}
@Benchmark
public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix());
INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix());
return firstMatrix.mmul(secondMatrix);
}
@Benchmark
public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) {
DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;
DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix());
DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix());
Algebra algebra = new Algebra();
return algebra.mult(firstMatrix, secondMatrix);
}
}

View File

@ -0,0 +1,46 @@
package com.baeldung.matrices.benchmark;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.infra.BenchmarkParams;
import java.util.Random;
import java.util.stream.DoubleStream;
@State(Scope.Benchmark)
public class BigMatrixProvider {
@Param({})
private int matrixSize;
private double[][] firstMatrix;
private double[][] secondMatrix;
public BigMatrixProvider() {}
@Setup
public void setup(BenchmarkParams parameters) {
firstMatrix = createMatrix(matrixSize);
secondMatrix = createMatrix(matrixSize);
}
private double[][] createMatrix(int matrixSize) {
Random random = new Random();
double[][] result = new double[matrixSize][matrixSize];
for (int row = 0; row < result.length; row++) {
for (int col = 0; col < result[row].length; col++) {
result[row][col] = random.nextDouble();
}
}
return result;
}
public double[][] getFirstMatrix() {
return firstMatrix;
}
public double[][] getSecondMatrix() {
return secondMatrix;
}
}

View File

@ -1,8 +1,9 @@
package com.baeldung.matrices;
package com.baeldung.matrices.benchmark;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import com.baeldung.matrices.HomemadeMatrix;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.ejml.simple.SimpleMatrix;
@ -23,9 +24,10 @@ public class MatrixMultiplicationBenchmarking {
public static void main(String[] args) throws Exception {
Options opt = new OptionsBuilder()
.include(MatrixMultiplicationBenchmarking.class.getSimpleName())
.exclude(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
.mode(Mode.AverageTime)
.forks(2)
.warmupIterations(5)
.warmupIterations(10)
.measurementIterations(10)
.timeUnit(TimeUnit.MICROSECONDS)
.build();

View File

@ -1,4 +1,4 @@
package com.baeldung.matrices;
package com.baeldung.matrices.benchmark;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;