diff --git a/java-math/src/main/java/com/baeldung/matrices/benchmark/BigMatrixMultiplicationBenchmarking.java b/java-math/src/main/java/com/baeldung/matrices/benchmark/BigMatrixMultiplicationBenchmarking.java new file mode 100644 index 0000000000..2ed983f733 --- /dev/null +++ b/java-math/src/main/java/com/baeldung/matrices/benchmark/BigMatrixMultiplicationBenchmarking.java @@ -0,0 +1,121 @@ +package com.baeldung.matrices.benchmark; + +import cern.colt.matrix.DoubleFactory2D; +import cern.colt.matrix.DoubleMatrix2D; +import cern.colt.matrix.linalg.Algebra; +import com.baeldung.matrices.HomemadeMatrix; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.RealMatrix; +import org.ejml.simple.SimpleMatrix; +import org.la4j.Matrix; +import org.la4j.matrix.dense.Basic2DMatrix; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.options.ChainedOptionsBuilder; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +public class BigMatrixMultiplicationBenchmarking { + private static final int DEFAULT_FORKS = 2; + private static final int DEFAULT_WARMUP_ITERATIONS = 5; + private static final int DEFAULT_MEASUREMENT_ITERATIONS = 10; + + public static void main(String[] args) throws Exception { + Map parameters = parseParameters(args); + + ChainedOptionsBuilder builder = new OptionsBuilder() + .include(BigMatrixMultiplicationBenchmarking.class.getSimpleName()) + .mode(Mode.AverageTime) + .forks(forks(parameters)) + .warmupIterations(warmupIterations(parameters)) + .measurementIterations(measurementIterations(parameters)) + .timeUnit(TimeUnit.SECONDS); + + parameters.forEach(builder::param); + + new Runner(builder.build()).run(); + } + + private static Map parseParameters(String[] args) { + return Arrays.stream(args) + .map(arg -> arg.split("=")) + .collect(Collectors.toMap( + arg -> arg[0], + arg -> arg[1] + )); + } + + private static int forks(Map parameters) { + String forks = parameters.remove("forks"); + return parseOrDefault(forks, DEFAULT_FORKS); + } + + private static int warmupIterations(Map parameters) { + String warmups = parameters.remove("warmupIterations"); + return parseOrDefault(warmups, DEFAULT_WARMUP_ITERATIONS); + } + + private static int measurementIterations(Map parameters) { + String measurements = parameters.remove("measurementIterations"); + return parseOrDefault(measurements, DEFAULT_MEASUREMENT_ITERATIONS); + } + + private static int parseOrDefault(String parameter, int defaultValue) { + return parameter != null ? Integer.parseInt(parameter) : defaultValue; + } + + @Benchmark + public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) { + return HomemadeMatrix.multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix()); + } + + @Benchmark + public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) { + SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix()); + SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix()); + + return firstMatrix.mult(secondMatrix); + } + + @Benchmark + public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) { + RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix()); + RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix()); + + return firstMatrix.multiply(secondMatrix); + } + + @Benchmark + public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) { + Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix()); + Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix()); + + return firstMatrix.multiply(secondMatrix); + } + + @Benchmark + public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) { + INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix()); + INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix()); + + return firstMatrix.mmul(secondMatrix); + } + + @Benchmark + public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) { + DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense; + + DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix()); + DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix()); + + Algebra algebra = new Algebra(); + return algebra.mult(firstMatrix, secondMatrix); + } +} \ No newline at end of file diff --git a/java-math/src/main/java/com/baeldung/matrices/benchmark/BigMatrixProvider.java b/java-math/src/main/java/com/baeldung/matrices/benchmark/BigMatrixProvider.java new file mode 100644 index 0000000000..d0f8a03fe3 --- /dev/null +++ b/java-math/src/main/java/com/baeldung/matrices/benchmark/BigMatrixProvider.java @@ -0,0 +1,46 @@ +package com.baeldung.matrices.benchmark; + +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.infra.BenchmarkParams; + +import java.util.Random; +import java.util.stream.DoubleStream; + +@State(Scope.Benchmark) +public class BigMatrixProvider { + @Param({}) + private int matrixSize; + private double[][] firstMatrix; + private double[][] secondMatrix; + + public BigMatrixProvider() {} + + @Setup + public void setup(BenchmarkParams parameters) { + firstMatrix = createMatrix(matrixSize); + secondMatrix = createMatrix(matrixSize); + } + + private double[][] createMatrix(int matrixSize) { + Random random = new Random(); + + double[][] result = new double[matrixSize][matrixSize]; + for (int row = 0; row < result.length; row++) { + for (int col = 0; col < result[row].length; col++) { + result[row][col] = random.nextDouble(); + } + } + return result; + } + + public double[][] getFirstMatrix() { + return firstMatrix; + } + + public double[][] getSecondMatrix() { + return secondMatrix; + } +} \ No newline at end of file diff --git a/java-math/src/test/java/com/baeldung/matrices/MatrixMultiplicationBenchmarking.java b/java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixMultiplicationBenchmarking.java similarity index 93% rename from java-math/src/test/java/com/baeldung/matrices/MatrixMultiplicationBenchmarking.java rename to java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixMultiplicationBenchmarking.java index 171a1d28a4..fdb423e8da 100644 --- a/java-math/src/test/java/com/baeldung/matrices/MatrixMultiplicationBenchmarking.java +++ b/java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixMultiplicationBenchmarking.java @@ -1,8 +1,9 @@ -package com.baeldung.matrices; +package com.baeldung.matrices.benchmark; import cern.colt.matrix.DoubleFactory2D; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.linalg.Algebra; +import com.baeldung.matrices.HomemadeMatrix; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.RealMatrix; import org.ejml.simple.SimpleMatrix; @@ -23,9 +24,10 @@ public class MatrixMultiplicationBenchmarking { public static void main(String[] args) throws Exception { Options opt = new OptionsBuilder() .include(MatrixMultiplicationBenchmarking.class.getSimpleName()) + .exclude(BigMatrixMultiplicationBenchmarking.class.getSimpleName()) .mode(Mode.AverageTime) .forks(2) - .warmupIterations(5) + .warmupIterations(10) .measurementIterations(10) .timeUnit(TimeUnit.MICROSECONDS) .build(); diff --git a/java-math/src/test/java/com/baeldung/matrices/MatrixProvider.java b/java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixProvider.java similarity index 94% rename from java-math/src/test/java/com/baeldung/matrices/MatrixProvider.java rename to java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixProvider.java index 33bd074b6e..d401ba2ab6 100644 --- a/java-math/src/test/java/com/baeldung/matrices/MatrixProvider.java +++ b/java-math/src/main/java/com/baeldung/matrices/benchmark/MatrixProvider.java @@ -1,4 +1,4 @@ -package com.baeldung.matrices; +package com.baeldung.matrices.benchmark; import org.openjdk.jmh.annotations.Scope; import org.openjdk.jmh.annotations.State;