diff --git a/src/main/java/org/apache/commons/math4/linear/AnyMatrix.java b/src/main/java/org/apache/commons/math4/linear/AnyMatrix.java index 0e1bb56fa..062ed9105 100644 --- a/src/main/java/org/apache/commons/math4/linear/AnyMatrix.java +++ b/src/main/java/org/apache/commons/math4/linear/AnyMatrix.java @@ -17,6 +17,7 @@ package org.apache.commons.math4.linear; +import org.apache.commons.math4.exception.DimensionMismatchException; /** * Interface defining very basic matrix operations. @@ -46,4 +47,51 @@ public interface AnyMatrix { * @return the number of columns. */ int getColumnDimension(); + + /** + * Checks that this matrix and the {@code other} matrix can be added. + * + * @param other Matrix to be added. + * @return {@code false} if the dimensions do not match. + */ + default boolean canAdd(AnyMatrix other) { + return getRowDimension() == other.getRowDimension() && + getColumnDimension() == other.getColumnDimension(); + } + + /** + * Checks that this matrix and the {@code other} matrix can be added. + * + * @param other Matrix to check. + * @throws IllegalArgumentException if the dimensions do not match. + */ + default void checkAdd(AnyMatrix other) { + if (!canAdd(other)) { + throw new MatrixDimensionMismatchException(getRowDimension(), getColumnDimension(), + other.getRowDimension(), other.getColumnDimension()); + } + } + + /** + * Checks that this matrix can be multiplied by the {@code other} matrix. + * + * @param other Matrix to be added. + * @return {@code false} if the dimensions do not match. + */ + default boolean canMultiply(AnyMatrix other) { + return getColumnDimension() == other.getRowDimension(); + } + + /** + * Checks that this matrix can be multiplied by the {@code other} matrix. + * + * @param other Matrix to check. + * @throws IllegalArgumentException if the dimensions do not match. + */ + default void checkMultiply(AnyMatrix other) { + if (!canMultiply(other)) { + throw new DimensionMismatchException(getColumnDimension(), + other.getRowDimension()); + } + } } diff --git a/src/main/java/org/apache/commons/math4/linear/MatrixUtils.java b/src/main/java/org/apache/commons/math4/linear/MatrixUtils.java index 04daf9432..47d30178c 100644 --- a/src/main/java/org/apache/commons/math4/linear/MatrixUtils.java +++ b/src/main/java/org/apache/commons/math4/linear/MatrixUtils.java @@ -582,13 +582,8 @@ public class MatrixUtils { * @throws MatrixDimensionMismatchException if the matrices are not addition * compatible. */ - public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right) - throws MatrixDimensionMismatchException { - if ((left.getRowDimension() != right.getRowDimension()) || - (left.getColumnDimension() != right.getColumnDimension())) { - throw new MatrixDimensionMismatchException(left.getRowDimension(), left.getColumnDimension(), - right.getRowDimension(), right.getColumnDimension()); - } + public static void checkAdditionCompatible(final AnyMatrix left, final AnyMatrix right) { + left.checkAdd(right); } /** @@ -599,13 +594,8 @@ public class MatrixUtils { * @throws MatrixDimensionMismatchException if the matrices are not addition * compatible. */ - public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right) - throws MatrixDimensionMismatchException { - if ((left.getRowDimension() != right.getRowDimension()) || - (left.getColumnDimension() != right.getColumnDimension())) { - throw new MatrixDimensionMismatchException(left.getRowDimension(), left.getColumnDimension(), - right.getRowDimension(), right.getColumnDimension()); - } + public static void checkSubtractionCompatible(final AnyMatrix left, final AnyMatrix right) { + left.checkAdd(right); } /** @@ -616,13 +606,8 @@ public class MatrixUtils { * @throws DimensionMismatchException if matrices are not multiplication * compatible. */ - public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right) - throws DimensionMismatchException { - - if (left.getColumnDimension() != right.getRowDimension()) { - throw new DimensionMismatchException(left.getColumnDimension(), - right.getRowDimension()); - } + public static void checkMultiplicationCompatible(final AnyMatrix left, final AnyMatrix right) { + left.checkMultiply(right); } /**