From e01294b1587ed698bea02e9e2b90cda36d0bf28b Mon Sep 17 00:00:00 2001 From: Luc Maisonobe Date: Fri, 13 Feb 2009 14:38:48 +0000 Subject: [PATCH] Failed attempt to implement Strassen matrix multiplication on recursive layout as described in Siddhartha Chatterjee, Alvin R. Lebeck, Praveen K. Patnala and Mithuna Thottehodi paper "Recursive Array Layout and Fast Matrix Multiplication". As of 2009-02-13, this implementation does not work! The padding at left and bottom sides of the matrix should be cleared after some operations like scalerAdd and is not. Also there is a limitation in the multiplication that can only process matrices with sizes similar enough to have the same power of two number of tiles in all three matrices A, B and C such that C = A*B. These parts have not been fixed since the performance gain with respect to DenseRealMatrix are not very important, and the numerical stability is poor. This may well be due to a bad implementation. This code has been put in the experimental directory for the record, putting it into production would require solving all these issues. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@744126 13f79535-47bb-0310-9956-ffa450edef68 --- .../linear/RecursiveLayoutRealMatrix.java | 2078 +++++++++++++++++ .../linear/RecursiveLayoutRealMatrixTest.java | 1242 ++++++++++ 2 files changed, 3320 insertions(+) create mode 100644 src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrix.java create mode 100644 src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrixTest.java diff --git a/src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrix.java b/src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrix.java new file mode 100644 index 000000000..3ba8928e5 --- /dev/null +++ b/src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrix.java @@ -0,0 +1,2078 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.math.linear; + +import java.io.Serializable; + +import org.apache.commons.math.MathRuntimeException; + +/** + * Cache-friendly implementation of RealMatrix using recursive array layouts to store + * the matrix elements. + *

+ * As of 2009-02-13, this implementation does not work! The padding at left and bottom + * sides of the matrix should be cleared after some operations like scalerAdd + * and is not. Also there is a limitation in the multiplication that can only + * process matrices with sizes similar enough to have the same power of two + * number of tiles in all three matrices A, B and C such that C = A*B. These + * parts have not been fixed since the performance gain with respect to + * DenseRealMatrix are not very important, and the numerical stability is not + * good. This may well be due to a bad implementation. This code has been put + * in the experimental part for the record, putting it into production would + * require solving all these issues. + *

+ *

+ * This implementation is based on the 2002 paper: Recursive Array Layouts + * and Fast Matrix Multiplication by Siddhartha Chatterjee, Alvin R. Lebeck, + * Praveen K. Patnala and Mithuna Thottethodi. + *

+ *

+ * The matrix is split into several rectangular tiles. The tiles are laid out using + * a space-filling curve in a 2k×2k square. This + * implementation uses the Gray-Morton layout which starts as follows for a three-level + * recursion (i.e. an 8x8 matrix). The tiles size are adjusted in order to have the + * 2k×2k square. This may require padding at the right and + * bottom sides of the matrix (see above paper for a discussion of this padding feature). + *

+ *
+ *                    |
+ *    00 01 | 06 07   |   24  25 | 30  31
+ *    03 02 | 05 04   |   27  26 | 29  28
+ *    ------+------   |   -------+-------
+ *    12 13 | 10 11   |   20  21 | 18  19
+ *    15 14 | 09 08   |   23  22 | 17  16
+ *                    |
+ * -------------------+--------------------
+ *                    |
+ *    48 49 | 54 55   |   40  41 | 46  47
+ *    51 50 | 53 52   |   43  42 | 45  44
+ *    ------+------   |   -------+-------
+ *    60 61 | 58 59   |   36  37 | 34  35
+ *    63 62 | 57 56   |   39  38 | 33  32
+ *                    |
+ * 
+ * @version $Revision$ $Date$ + * @since 2.0 + */ +public class RecursiveLayoutRealMatrix extends AbstractRealMatrix implements Serializable { + + /** Serializable version identifier */ + private static final long serialVersionUID = 1607919006739190004L; + + /** Maximal allowed tile size in bytes. + *

In order to avoid cache miss during multiplication, + * a suggested value is cache_size/3.

+ */ + private static final int MAX_TILE_SIZE_BYTES = (64 * 1024) / 3; + //private static final int MAX_TILE_SIZE_BYTES = 32; + + /** Storage array for matrix elements. */ + private final double data[]; + + /** Number of rows of the matrix. */ + private final int rows; + + /** Number of columns of the matrix. */ + private final int columns; + + /** Number of terminal tiles along rows and columns (guaranteed to be a power of 2). */ + private final int tileNumber; + + /** Number of rows in each terminal tile. */ + private final int tileSizeRows; + + /** Number of columns in each terminal tile. */ + private final int tileSizeColumns; + + /** + * Create a new matrix with the supplied row and column dimensions. + * + * @param rows the number of rows in the new matrix + * @param columns the number of columns in the new matrix + * @throws IllegalArgumentException if row or column dimension is not + * positive + */ + public RecursiveLayoutRealMatrix(final int rows, final int columns) + throws IllegalArgumentException { + + super(rows, columns); + this.rows = rows; + this.columns = columns; + + // compute optimal layout + tileNumber = tilesNumber(rows, columns); + tileSizeRows = tileSize(rows, tileNumber); + tileSizeColumns = tileSize(columns, tileNumber); + + // create storage array + data = new double[tileNumber * tileNumber * tileSizeRows * tileSizeColumns]; + + } + + /** + * Create a new dense matrix copying entries from raw layout data. + *

The input array must be in raw layout.

+ *

Calling this constructor is equivalent to call: + *

matrix = new RecursiveLayoutRealMatrix(rawData.length, rawData[0].length,
+     *                                             toRecursiveLayout(rawData), false);
+ *

+ * @param rawData data for new matrix, in raw layout + * + * @exception IllegalArgumentException if rawData shape is + * inconsistent with tile layout + * @see #DenseRealMatrix(int, int, double[][], boolean) + */ + public RecursiveLayoutRealMatrix(final double[][] rawData) + throws IllegalArgumentException { + this(rawData.length, rawData[0].length, toRecursiveLayout(rawData), false); + } + + /** + * Create a new dense matrix copying entries from recursive layout data. + *

The input array must already be in recursive layout.

+ * @param rows the number of rows in the new matrix + * @param columns the number of columns in the new matrix + * @param data data for new matrix, in recursive layout + * @param copyArray if true, the input array will be copied, otherwise + * it will be referenced + * + * @exception IllegalArgumentException if data size is + * inconsistent with matrix size + * @see #toRecursiveLayout(double[][]) + * @see #RecursiveLayoutRealMatrix(double[][]) + */ + public RecursiveLayoutRealMatrix(final int rows, final int columns, + final double[] data, final boolean copyArray) + throws IllegalArgumentException { + + super(rows, columns); + this.rows = rows; + this.columns = columns; + + // compute optimal layout + tileNumber = tilesNumber(rows, columns); + tileSizeRows = tileSize(rows, tileNumber); + tileSizeColumns = tileSize(columns, tileNumber); + + // create storage array + final int expectedLength = tileNumber * tileNumber * tileSizeRows * tileSizeColumns; + if (data.length != expectedLength) { + throw MathRuntimeException.createIllegalArgumentException("wrong array size (got {0}, expected {1})", + new Object[] { + data.length, + expectedLength + }); + } + + if (copyArray) { + // allocate storage array + this.data = data.clone(); + } else { + // reference existing array + this.data = data; + } + + } + + /** + * Convert a data array from raw layout to recursive layout. + *

+ * Raw layout is the straightforward layout where element at row i and + * column j is in array element rawData[i][j]. Recursive layout + * is the layout used in {@link RecursiveLayoutRealMatrix} instances, where the matrix + * is stored in a dimension 1 array using a space-filling curve to spread the matrix + * elements along the array. + *

+ * @param rawData data array in raw layout + * @return a new data array containing the same entries but in recursive layout + * @exception IllegalArgumentException if rawData is not rectangular + * (not all rows have the same length) + * @see #RecursiveLayoutRealMatrix(int, int, double[], boolean) + */ + public static double[] toRecursiveLayout(final double[][] rawData) + throws IllegalArgumentException { + + final int rows = rawData.length; + final int columns = rawData[0].length; + + // compute optimal layout + final int tileNumber = tilesNumber(rows, columns); + final int tileSizeRows = tileSize(rows, tileNumber); + final int tileSizeColumns = tileSize(columns, tileNumber); + + // safety checks + for (int i = 0; i < rawData.length; ++i) { + final int length = rawData[i].length; + if (length != columns) { + throw MathRuntimeException.createIllegalArgumentException( + "some rows have length {0} while others have length {1}", + new Object[] { columns, length }); + } + } + + // convert array row after row + final double[] data = new double[tileNumber * tileNumber * tileSizeRows * tileSizeColumns]; + for (int i = 0; i < rawData.length; ++i) { + final int iTile = i / tileSizeRows; + final double[] rawDataI = rawData[i]; + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int tileStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns; + final int dataStart = tileStart + (i - iTile * tileSizeRows) * tileSizeColumns; + final int jStart = jTile * tileSizeColumns; + if (jStart < columns) { + final int jEnd = Math.min(jStart + tileSizeColumns, columns); + System.arraycopy(rawDataI, jStart, data, dataStart, jEnd - jStart); + } + } + } + + return data; + + } + + /** {@inheritDoc} */ + public RealMatrix createMatrix(final int rowDimension, final int columnDimension) + throws IllegalArgumentException { + return new RecursiveLayoutRealMatrix(rowDimension, columnDimension); + } + + /** {@inheritDoc} */ + public RealMatrix copy() { + return new RecursiveLayoutRealMatrix(rows, columns, data, true); + } + + /** {@inheritDoc} */ + public RealMatrix add(final RealMatrix m) + throws IllegalArgumentException { + try { + return add((RecursiveLayoutRealMatrix) m); + } catch (ClassCastException cce) { + + // safety check + checkAdditionCompatible(m); + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, columns); + + // perform addition tile-wise, to ensure good cache behavior + for (int index = 0; index < tileNumber * tileNumber; ++index) { + + // perform addition on the current tile + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int p = pStart; p < pEnd; ++p) { + final int kStart = tileStart + (p - pStart) * tileSizeColumns; + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + out.data[k] = data[k] + m.getEntry(p, q); + } + } + + } + + return out; + + } + } + + /** + * Compute the sum of this and m. + * + * @param m matrix to be added + * @return this + m + * @throws IllegalArgumentException if m is not the same size as this + */ + public RecursiveLayoutRealMatrix add(final RecursiveLayoutRealMatrix m) + throws IllegalArgumentException { + + // safety check + checkAdditionCompatible(m); + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, columns); + + // streamlined addition + for (int i = 0; i < data.length; ++i) { + out.data[i] = data[i] + m.data[i]; + } + + return out; + + } + + /** {@inheritDoc} */ + public RealMatrix subtract(final RealMatrix m) + throws IllegalArgumentException { + try { + return subtract((RecursiveLayoutRealMatrix) m); + } catch (ClassCastException cce) { + + // safety check + checkSubtractionCompatible(m); + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, columns); + + // perform subtraction tile-wise, to ensure good cache behavior + for (int index = 0; index < tileNumber * tileNumber; ++index) { + + // perform addition on the current tile + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int p = pStart; p < pEnd; ++p) { + final int kStart = tileStart + (p - pStart) * tileSizeColumns; + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + out.data[k] = data[k] - m.getEntry(p, q); + } + } + + } + + return out; + + } + } + + /** + * Compute this minus m. + * + * @param m matrix to be subtracted + * @return this - m + * @throws IllegalArgumentException if m is not the same size as this + */ + public RecursiveLayoutRealMatrix subtract(final RecursiveLayoutRealMatrix m) + throws IllegalArgumentException { + + // safety check + checkSubtractionCompatible(m); + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, columns); + + // streamlined subtraction + for (int i = 0; i < data.length; ++i) { + out.data[i] = data[i] - m.data[i]; + } + + return out; + + } + + /** {@inheritDoc} */ + public RealMatrix scalarAdd(final double d) + throws IllegalArgumentException { + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, columns); + + // streamlined addition + for (int i = 0; i < data.length; ++i) { + out.data[i] = data[i] + d; + } + + return out; + + } + + /** {@inheritDoc} */ + public RealMatrix scalarMultiply(final double d) + throws IllegalArgumentException { + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, columns); + + // streamlined multiplication + for (int i = 0; i < data.length; ++i) { + out.data[i] = data[i] * d; + } + + return out; + + } + + /** {@inheritDoc} */ + public RealMatrix multiply(final RealMatrix m) + throws IllegalArgumentException { + try { + return multiply((RecursiveLayoutRealMatrix) m); + } catch (ClassCastException cce) { + + // safety check + checkMultiplicationCompatible(m); + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, m.getColumnDimension()); + + // perform multiplication tile-wise, to ensure good cache behavior + for (int index = 0; index < out.tileNumber * out.tileNumber; ++index) { + final int tileStart = index * out.tileSizeRows * out.tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int iStart = iTile * out.tileSizeRows; + final int iEnd = Math.min(iStart + out.tileSizeRows, out.rows); + final int jStart = jTile * out.tileSizeColumns; + final int jEnd = Math.min(jStart + out.tileSizeColumns, out.columns); + + // perform multiplication for current tile + for (int kTile = 0; kTile < tileNumber; ++kTile) { + final int kTileStart = tileIndex(iTile, kTile) * tileSizeRows * tileSizeColumns; + for (int i = iStart, lStart = kTileStart, oStart = tileStart; + i < iEnd; + ++i, lStart += tileSizeColumns, oStart += out.tileSizeColumns) { + final int lEnd = Math.min(lStart + tileSizeColumns, columns); + for (int j = jStart, o = oStart; j < jEnd; ++j, ++o) { + double sum = 0; + for (int l = lStart, k = kTile * tileSizeColumns; l < lEnd; ++l, ++k) { + sum += data[l] * m.getEntry(k, j); + } + out.data[o] += sum; + } + } + } + } + + return out; + + } + } + + /** + * Returns the result of postmultiplying this by m. + *

The Strassen matrix multiplication method is used here. This + * method computes C = A × B recursively by splitting all matrices + * in four quadrants and computing:

+ *
+     * P1 = (A1,1 + A2,2) × (B1,1 + B2,2)
+     * P2 = (A2,1 + A2,2) × (B1,1)
+     * P3 = (A1,1) × (B1,2 - B2,2)
+     * P4 = (A2,2) × (B2,1 - B1,1)
+     * P5 = (A1,1 + A1,2) × B2,2
+     * P6 = (A2,1 - A1,1) × (B1,1 + B1,2)
+     * P7 = (A1,2 - A2,2) × (B2,1 + B2,2)
+     *
+     * C1,1 = P1 + P4 - P5 + P7
+     * C1,2 = P3 + P5
+     * C2,1 = P2 + P4
+     * C2,2 = P1 + P3 - P2 + P6
+     * 
+ *

+ * This implementation is based on the 2002 paper: Recursive Array Layouts + * and Fast Matrix Multiplication by Siddhartha Chatterjee, Alvin R. Lebeck, + * Praveen K. Patnala and Mithuna Thottethodi. + *

+ * + * @param m matrix to postmultiply by + * @return this * m + * @throws IllegalArgumentException + * if columnDimension(this) != rowDimension(m) + */ + public RecursiveLayoutRealMatrix multiply(RecursiveLayoutRealMatrix m) + throws IllegalArgumentException { + + // safety check + checkMultiplicationCompatible(m); + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, m.columns); + if ((tileNumber != m.tileNumber) || (tileNumber != out.tileNumber)) { + // TODO get rid of this test + throw new RuntimeException("multiplication " + rows + "x" + columns + " * " + + m.rows + "x" + m.columns + " -> left matrix: " + tileNumber + + " tiles, right matrix: " + m.tileNumber + " tiles, result matrix " + + out.tileNumber + " tiles"); + } + strassenMultiply(data, 0, true, m.data, 0, true, out.data, 0, tileNumber, + tileSizeRows, m.tileSizeColumns, tileSizeColumns); + + return out; + + } + + /** + * Perform recursive multiplication using Strassen's algorithm. + * @param a left term of multiplication + * @param aStart start index in a + * @param aDirect direct/reversed orientation flag for a + * @param b right term of multiplication + * @param bStart start index in b + * @param bDirect direct/reversed orientation flag for b + * @param result result array (will have same orientation as b) + * @param resultStart start index in result + * @param nTiles number of elements to add + * @param bsRows number of rows in result tiles + * @param bsColumns number of columns in result tiles + * @param bsMultiplicands number of rows/columns in multiplicands + */ + private static void strassenMultiply(final double[] a, final int aStart, final boolean aDirect, + final double[] b, final int bStart, final boolean bDirect, + final double[] result, final int resultStart, final int nTiles, + final int bsRows, final int bsColumns, final int bsMultiplicands) { + if (nTiles == 1) { + // leaf recursion tile: perform traditional multiplication + final int bsColumns2 = 2 * bsColumns; + final int bsColumns3 = 3 * bsColumns; + final int bsColumns4 = 4 * bsColumns; + for (int i = 0; i < bsRows; ++i) { + for (int j = 0; j < bsColumns; ++j) { + double sum = 0; + int k = 0; + int aK = aStart + i * bsMultiplicands; + int bK = bStart + j; + while (k < bsMultiplicands - 3) { + sum += a[aK] * b[bK] + + a[aK + 1] * b[bK + bsColumns] + + a[aK + 2] * b[bK + bsColumns2] + + a[aK + 3] * b[bK + bsColumns3]; + k += 4; + aK += 4; + bK += bsColumns4; + } + while (k < bsMultiplicands) { + sum += a[aK] * b[bK]; + k += 1; + aK += 1; + bK += bsColumns; + } + result[resultStart + i * bsColumns + j] = sum; + } + } + } else { + // regular recursion node: use recursive Strassen implementation + final int n2 = nTiles / 2; + final int aQuadrantSize = bsRows * n2 * bsMultiplicands * n2; + final int bQuadrantSize = bsMultiplicands * n2 * bsColumns * n2; + final int cQuadrantSize = bsRows * n2 * bsColumns * n2; + final double[] sA = new double[aQuadrantSize]; + final double[] sB = new double[bQuadrantSize]; + final boolean nonLeafQuadrants = n2 > 1; + + // identify A quadrants start indices + final int a11Start, a12Start, a21Start, a22Start; + if (aDirect) { + a11Start = aStart; + a12Start = aStart + aQuadrantSize; + a21Start = aStart + 3 * aQuadrantSize; + a22Start = aStart + 2 * aQuadrantSize; + } else { + a11Start = aStart + 2 * aQuadrantSize; + a12Start = aStart + 3 * aQuadrantSize; + a21Start = aStart + aQuadrantSize; + a22Start = aStart; + } + + // identify B and C quadrants start indices + // (C is constructed with the same orientation as B) + final int b11Start, b12Start, b21Start, b22Start; + final int c11Start, c12Start, c21Start, c22Start; + if (bDirect) { + b11Start = bStart; + b12Start = bStart + bQuadrantSize; + b21Start = bStart + 3 * bQuadrantSize; + b22Start = bStart + 2 * bQuadrantSize; + c11Start = resultStart; + c12Start = resultStart + cQuadrantSize; + c21Start = resultStart + 3 * cQuadrantSize; + c22Start = resultStart + 2 * cQuadrantSize; + } else { + b11Start = bStart + 2 * bQuadrantSize; + b12Start = bStart + 3 * bQuadrantSize; + b21Start = bStart + bQuadrantSize; + b22Start = bStart; + c11Start = resultStart + 2 * cQuadrantSize; + c12Start = resultStart + 3 * cQuadrantSize; + c21Start = resultStart + cQuadrantSize; + c22Start = resultStart; + } + + // optimal order for cache efficiency: P3, P6, P2, P1, P5, P7, P4 + + // P3 = (A11)(B12 - B22) + // C12 = P3 + ... + tilesSubtract(b, b12Start, false, b, b22Start, false, sB, 0, + bQuadrantSize, nonLeafQuadrants); + strassenMultiply(a, a11Start, true, sB, 0, false, result, c12Start, + n2, bsRows, bsColumns, bsMultiplicands); + + // P6 = (A21 - A11)(B11 + B12) + // C22 = P3 + P6 + ... + final double[] p67 = new double[cQuadrantSize]; + tilesSubtract(a, a21Start, true, a, a11Start, true, sA, 0, + aQuadrantSize, nonLeafQuadrants); + tilesAdd(b, b11Start, true, b, b12Start, false, sB, 0, + bQuadrantSize, nonLeafQuadrants); + strassenMultiply(sA, 0, true, sB, 0, true, p67, 0, + n2, bsRows, bsColumns, bsMultiplicands); + tilesAdd(result, c12Start, false, p67, 0, true, result, c22Start, + cQuadrantSize, nonLeafQuadrants); + + // P2 = (A21 + A22)(B11) + // C21 = P2 + ... + // C22 = P3 + P6 - P2 + ... + tilesAdd(a, a21Start, true, a, a22Start, false, sA, 0, + aQuadrantSize, nonLeafQuadrants); + strassenMultiply(sA, 0, true, b, b11Start, true, result, c21Start, + n2, bsRows, bsColumns, bsMultiplicands); + tilesSelfSubtract(result, c22Start, false, result, c21Start, true, + cQuadrantSize, nonLeafQuadrants); + + // P1 = (A11 + A22)(B11 + B22) + // C11 = P1 + ... + // C22 = P3 + P6 - P2 + P1 + tilesAdd(a, a11Start, true, a, a22Start, false, sA, 0, + aQuadrantSize, nonLeafQuadrants); + tilesAdd(b, b11Start, true, b, b22Start, false, sB, 0, + bQuadrantSize, nonLeafQuadrants); + strassenMultiply(sA, 0, true, sB, 0, true, result, c11Start, + n2, bsRows, bsColumns, bsMultiplicands); + tilesSelfAdd(result, c22Start, false, result, c11Start, true, + cQuadrantSize, nonLeafQuadrants); + + // P5 = (A11 + A12)B22 + // beware: there is a sign error here in Chatterjee et al. paper + // in figure 1, table b they subtract A12 from A11 instead of adding it + // C12 = P3 + P5 + // C11 = P1 - P5 + ... + final double[] p45 = new double[cQuadrantSize]; + tilesAdd(a, a11Start, true, a, a12Start, false, sA, 0, + aQuadrantSize, nonLeafQuadrants); + strassenMultiply(sA, 0, true, b, b22Start, false, p45, 0, + n2, bsRows, bsColumns, bsMultiplicands); + tilesSelfAdd(result, c12Start, false, p45, 0, false, + cQuadrantSize, nonLeafQuadrants); + tilesSelfSubtract(result, c11Start, true, p45, 0, false, + cQuadrantSize, nonLeafQuadrants); + + // P7 = (A12 - A22)(B21 + B22) + // C11 = P1 - P5 + P7 + ... + tilesSubtract(a, a12Start, false, a, a22Start, false, sA, 0, + aQuadrantSize, nonLeafQuadrants); + tilesAdd(b, b21Start, true, b, b22Start, false, sB, 0, + bQuadrantSize, nonLeafQuadrants); + strassenMultiply(sA, 0, false, sB, 0, true, p67, 0, + n2, bsRows, bsColumns, bsMultiplicands); + tilesSelfAdd(result, c11Start, true, p67, 0, true, + cQuadrantSize, nonLeafQuadrants); + + // P4 = (A22)(B21 - B11) + // C11 = P1 - P5 + P7 + P4 + // C21 = P2 + P4 + tilesSubtract(b, b21Start, true, b, b11Start, true, sB, 0, + bQuadrantSize, nonLeafQuadrants); + strassenMultiply(a, a22Start, false, sB, 0, true, p45, 0, + n2, bsRows, bsColumns, bsMultiplicands); + tilesSelfAdd(result, c11Start, true, p45, 0, true, + cQuadrantSize, nonLeafQuadrants); + tilesSelfAdd(result, c21Start, true, p45, 0, true, + cQuadrantSize, nonLeafQuadrants); + + } + } + + /** + * Perform an addition on a few tiles in arrays. + * @param a left term of addition + * @param aStart start index in a + * @param aDirect direct/reversed orientation flag for a + * @param b right term of addition + * @param bStart start index in b + * @param bDirect direct/reversed orientation flag for b + * @param result result array (will have same orientation as a) + * @param resultStart start index in result + * @param n number of elements to add + * @param nonLeafQuadrants if true the quadrant can be further decomposed + */ + private static void tilesAdd(final double[] a, final int aStart, final boolean aDirect, + final double[] b, final int bStart, final boolean bDirect, + final double[] result, final int resultStart, + final int n, final boolean nonLeafQuadrants) { + if ((aDirect ^ bDirect) & nonLeafQuadrants) { + // a and b have different orientations + // perform addition in two half + final int n2 = n / 2; + addLoop(a, aStart, b, bStart + n2, result, resultStart, n2); + addLoop(a, aStart + n2, b, bStart, result, resultStart + n2, n2); + } else { + // a and b have same orientations + // perform addition in one loop + addLoop(a, aStart, b, bStart, result, resultStart, n); + } + } + + /** + * Perform an addition loop. + * @param a left term of addition + * @param aStart start index in a + * @param b right term of addition + * @param bStart start index in b + * @param result result array (will have same orientation as a) + * @param resultStart start index in result + * @param n number of elements to add + */ + private static void addLoop(final double[] a, final int aStart, + final double[] b, final int bStart, + final double[] result, final int resultStart, + final int n) { + int i = 0; + while (i < n - 3) { + final int r0 = resultStart + i; + final int a0 = aStart + i; + final int b0 = bStart + i; + result[r0] = a[a0] + b[b0]; + result[r0 + 1] = a[a0 + 1] + b[b0 + 1]; + result[r0 + 2] = a[a0 + 2] + b[b0 + 2]; + result[r0 + 3] = a[a0 + 3] + b[b0 + 3]; + i += 4; + } + while (i < n) { + result[resultStart + i] = a[aStart + i] + b[bStart + i]; + ++i; + } + } + + /** + * Perform a subtraction on a few tiles in arrays. + * @param a left term of subtraction + * @param aStart start index in a + * @param aDirect direct/reversed orientation flag for a + * @param b right term of subtraction + * @param bStart start index in b + * @param bDirect direct/reversed orientation flag for b + * @param result result array (will have same orientation as a) + * @param resultStart start index in result + * @param n number of elements to subtract + * @param nonLeafQuadrants if true the quadrant can be further decomposed + */ + private static void tilesSubtract(final double[] a, final int aStart, final boolean aDirect, + final double[] b, final int bStart, final boolean bDirect, + final double[] result, final int resultStart, + final int n, final boolean nonLeafQuadrants) { + if ((aDirect ^ bDirect) & nonLeafQuadrants) { + // a and b have different orientations + // perform subtraction in two half + final int n2 = n / 2; + subtractLoop(a, aStart, b, bStart + n2, result, resultStart, n2); + subtractLoop(a, aStart + n2, b, bStart, result, resultStart + n2, n2); + } else { + // a and b have same orientations + // perform subtraction in one loop + subtractLoop(a, aStart, b, bStart, result, resultStart, n); + } + } + + /** + * Perform a subtraction loop. + * @param a left term of subtraction + * @param aStart start index in a + * @param b right term of subtraction + * @param bStart start index in b + * @param result result array (will have same orientation as a) + * @param resultStart start index in result + * @param n number of elements to subtract + */ + private static void subtractLoop(final double[] a, final int aStart, + final double[] b, final int bStart, + final double[] result, final int resultStart, + final int n) { + int i = 0; + while (i < n - 3) { + final int r0 = resultStart + i; + final int a0 = aStart + i; + final int b0 = bStart + i; + result[r0] = a[a0] - b[b0]; + result[r0 + 1] = a[a0 + 1] - b[b0 + 1]; + result[r0 + 2] = a[a0 + 2] - b[b0 + 2]; + result[r0 + 3] = a[a0 + 3] - b[b0 + 3]; + i += 4; + } + while (i < n) { + result[resultStart + i] = a[aStart + i] - b[bStart + i]; + ++i; + } + } + + /** + * Perform a self-addition on a few tiles in arrays. + * @param a left term of addition (will be overwritten with result) + * @param aStart start index in a + * @param aDirect direct/reversed orientation flag for a + * @param b right term of addition + * @param bStart start index in b + * @param bDirect direct/reversed orientation flag for b + * @param n number of elements to add + * @param nonLeafQuadrants if true the quadrant can be further decomposed + */ + private static void tilesSelfAdd(final double[] a, final int aStart, final boolean aDirect, + final double[] b, final int bStart, final boolean bDirect, + final int n, final boolean nonLeafQuadrants) { + if ((aDirect ^ bDirect) & nonLeafQuadrants) { + // a and b have different orientations + // perform addition in two half + final int n2 = n / 2; + selfAddLoop(a, aStart, b, bStart + n2, n2); + selfAddLoop(a, aStart + n2, b, bStart, n2); + } else { + // a and b have same orientations + // perform addition in one loop + selfAddLoop(a, aStart, b, bStart, n); + } + } + + /** + * Perform a self-addition loop. + * @param a left term of addition (will be overwritten with result) + * @param aStart start index in a + * @param b right term of addition + * @param bStart start index in b + * @param n number of elements to add + */ + private static void selfAddLoop(final double[] a, final int aStart, + final double[] b, final int bStart, + final int n) { + int i = 0; + while (i < n - 3) { + final int a0 = aStart + i; + final int b0 = bStart + i; + a[a0] += b[b0]; + a[a0 + 1] += b[b0 + 1]; + a[a0 + 2] += b[b0 + 2]; + a[a0 + 3] += b[b0 + 3]; + i += 4; + } + while (i < n) { + a[aStart + i] += b[bStart + i]; + ++i; + } + } + + /** + * Perform a self-subtraction on a few tiles in arrays. + * @param a left term of subtraction (will be overwritten with result) + * @param aStart start index in a + * @param aDirect direct/reversed orientation flag for a + * @param b right term of subtraction + * @param bStart start index in b + * @param bDirect direct/reversed orientation flag for b + * @param n number of elements to subtract + * @param nonLeafQuadrants if true the quadrant can be further decomposed + */ + private static void tilesSelfSubtract(final double[] a, final int aStart, final boolean aDirect, + final double[] b, final int bStart, final boolean bDirect, + final int n, final boolean nonLeafQuadrants) { + if ((aDirect ^ bDirect) & nonLeafQuadrants) { + // a and b have different orientations + // perform subtraction in two half + final int n2 = n / 2; + selfSubtractLoop(a, aStart, b, bStart + n2, n2); + selfSubtractLoop(a, aStart + n2, b, bStart, n2); + } else { + // a and b have same orientations + // perform subtraction in one loop + selfSubtractLoop(a, aStart, b, bStart, n); + } + } + + /** + * Perform a self-subtraction loop. + * @param a left term of subtraction (will be overwritten with result) + * @param aStart start index in a + * @param b right term of subtraction + * @param bStart start index in b + * @param n number of elements to subtract + */ + private static void selfSubtractLoop(final double[] a, final int aStart, + final double[] b, final int bStart, + final int n) { + int i = 0; + while (i < n - 3) { + final int a0 = aStart + i; + final int b0 = bStart + i; + a[a0] -= b[b0]; + a[a0 + 1] -= b[b0 + 1]; + a[a0 + 2] -= b[b0 + 2]; + a[a0 + 3] -= b[b0 + 3]; + i += 4; + } + while (i < n) { + a[aStart + i] -= b[bStart + i]; + ++i; + } + } + + /** {@inheritDoc} */ + public double[][] getData() { + + final double[][] out = new double[rows][columns]; + + // perform extraction tile-wise, to ensure good cache behavior + for (int index = 0; index < tileNumber * tileNumber; ++index) { + + // perform extraction on the current tile + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int pStart = iTile * tileSizeRows; + final int qStart = jTile * tileSizeColumns; + if (pStart < rows && qStart < columns) { + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + int tileRowStart = tileStart; + for (int p = pStart; p < pEnd; ++p) { + System.arraycopy(data, tileRowStart, out[p], qStart, qEnd - qStart); + tileRowStart += tileSizeColumns; + } + } + + } + + return out; + + } + + /** {@inheritDoc} */ + public double getFrobeniusNorm() { + double sum2 = 0; + for (final double entry : data) { + sum2 += entry * entry; + } + return Math.sqrt(sum2); + } + + /** {@inheritDoc} */ + public RealMatrix getSubMatrix(final int startRow, final int endRow, + final int startColumn, final int endColumn) + throws MatrixIndexException { + + // safety checks + checkSubMatrixIndex(startRow, endRow, startColumn, endColumn); + + // create the output matrix + final RecursiveLayoutRealMatrix out = + new RecursiveLayoutRealMatrix(endRow - startRow + 1, endColumn - startColumn + 1); + + // perform extraction tile-wise, to ensure good cache behavior + for (int iTile = 0; iTile < out.tileNumber; ++iTile) { + final int iStart = startRow + iTile * out.tileSizeRows; + final int iEnd = Math.min(startRow + Math.min((iTile + 1) * out.tileSizeRows, out.rows), + endRow + 1); + for (int jTile = 0; jTile < out.tileNumber; ++jTile) { + final int jStart = startColumn + jTile * out.tileSizeColumns; + final int jEnd = Math.min(startColumn + Math.min((jTile + 1) * out.tileSizeColumns, out.columns), + endColumn + 1); + + // the current output tile may expand on more than one instance tile + for (int pTile = iStart / tileSizeRows; pTile * tileSizeRows < iEnd; ++pTile) { + final int p0 = pTile * tileSizeRows; + final int pStart = Math.max(p0, iStart); + final int pEnd = Math.min(Math.min(p0 + tileSizeRows, endRow + 1), iEnd); + for (int qTile = jStart / tileSizeColumns; qTile * tileSizeColumns < jEnd; ++qTile) { + final int q0 = qTile * tileSizeColumns; + final int qStart = Math.max(q0, jStart); + final int qEnd = Math.min(Math.min(q0 + tileSizeColumns, endColumn + 1), jEnd); + + // copy the overlapping part of instance and output tiles + int outIndex = tileIndex(iTile, jTile) * out.tileSizeRows * out.tileSizeColumns + + (pStart - iStart) * out.tileSizeColumns + (qStart - jStart); + int index = tileIndex(pTile, qTile) * tileSizeRows * tileSizeColumns + + (pStart - p0) * tileSizeColumns + (qStart - q0); + for (int p = pStart; p < pEnd; ++p) { + System.arraycopy(data, index, out.data, outIndex, qEnd - qStart); + outIndex += out.tileSizeColumns; + index += tileSizeColumns; + } + + + } + } + + } + } + + return out; + + } + + /** {@inheritDoc} */ + public void setSubMatrix(final double[][] subMatrix, final int row, final int column) + throws MatrixIndexException { + + // safety checks + final int refLength = subMatrix[0].length; + if (refLength < 1) { + throw MathRuntimeException.createIllegalArgumentException("matrix must have at least one column", + null); + } + final int endRow = row + subMatrix.length - 1; + final int endColumn = column + refLength - 1; + checkSubMatrixIndex(row, endRow, column, endColumn); + for (final double[] subRow : subMatrix) { + if (subRow.length != refLength) { + throw MathRuntimeException.createIllegalArgumentException("some rows have length {0} while others have length {1}", + new Object[] { + refLength, subRow.length + }); + } + } + + // compute tiles bounds + final int tileStartRow = row / tileSizeRows; + final int tileEndRow = (endRow + tileSizeRows) / tileSizeRows; + final int tileStartColumn = column / tileSizeColumns; + final int tileEndColumn = (endColumn + tileSizeColumns) / tileSizeColumns; + + // perform copy tile-wise, to ensure good cache behavior + for (int iTile = tileStartRow; iTile < tileEndRow; ++iTile) { + final int firstRow = iTile * tileSizeRows; + final int iStart = Math.max(row, firstRow); + final int iEnd = Math.min(endRow + 1, firstRow + tileSizeRows); + + for (int jTile = tileStartColumn; jTile < tileEndColumn; ++jTile) { + final int firstColumn = jTile * tileSizeColumns; + final int jStart = Math.max(column, firstColumn); + final int jEnd = Math.min(endColumn + 1, firstColumn + tileSizeColumns); + final int jLength = jEnd - jStart; + final int tileStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns; + + // handle one tile, row by row + for (int i = iStart; i < iEnd; ++i) { + System.arraycopy(subMatrix[i - row], jStart - column, + data, tileStart + (i - firstRow) * tileSizeColumns + (jStart - firstColumn), + jLength); + } + + } + } + } + + /** {@inheritDoc} */ + public RealMatrix getRowMatrix(final int row) + throws MatrixIndexException { + + checkRowIndex(row); + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(1, columns); + + // a row matrix has always only one large tile, + // because a single row cannot be split into 2^k tiles + // perform copy tile-wise, to ensure good cache behavior + final int iTile = row / tileSizeRows; + final int rowOffset = row - iTile * tileSizeRows; + int outIndex = 0; + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + rowOffset * tileSizeColumns; + final int length = Math.min(outIndex + tileSizeColumns, columns) - outIndex; + System.arraycopy(data, kStart, out.data, outIndex, length); + outIndex += length; + } + + return out; + + } + + /** {@inheritDoc} */ + public void setRowMatrix(final int row, final RealMatrix matrix) + throws MatrixIndexException, InvalidMatrixException { + try { + setRowMatrix(row, (RecursiveLayoutRealMatrix) matrix); + } catch (ClassCastException cce) { + super.setRowMatrix(row, matrix); + } + } + + /** + * Sets the entries in row number row + * as a row matrix. Row indices start at 0. + * + * @param row the row to be set + * @param matrix row matrix (must have one row and the same number of columns + * as the instance) + * @throws MatrixIndexException if the specified row index is invalid + * @throws InvalidMatrixException if the matrix dimensions do not match one + * instance row + */ + public void setRowMatrix(final int row, final RecursiveLayoutRealMatrix matrix) + throws MatrixIndexException, InvalidMatrixException { + + checkRowIndex(row); + final int nCols = getColumnDimension(); + if ((matrix.getRowDimension() != 1) || + (matrix.getColumnDimension() != nCols)) { + throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}", + new Object[] { + matrix.getRowDimension(), + matrix.getColumnDimension(), + 1, nCols + }); + } + + // a row matrix has always only one large tile, + // because a single row cannot be split into 2^k tiles + // perform copy tile-wise, to ensure good cache behavior + final int iTile = row / tileSizeRows; + final int rowOffset = row - iTile * tileSizeRows; + int outIndex = 0; + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + rowOffset * tileSizeColumns; + final int length = Math.min(outIndex + tileSizeColumns, columns) - outIndex; + System.arraycopy(matrix.data, outIndex, data, kStart, length); + outIndex += length; + } + + } + + /** {@inheritDoc} */ + public RealMatrix getColumnMatrix(final int column) + throws MatrixIndexException { + + checkColumnIndex(column); + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(rows, 1); + + // a column matrix has always only one large tile, + // because a single column cannot be split into 2^k tiles + // perform copy tile-wise, to ensure good cache behavior + final int jTile = column / tileSizeColumns; + final int columnOffset = column - jTile * tileSizeColumns; + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + columnOffset; + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + out.data[p] = data[k]; + } + } + + return out; + + } + + /** {@inheritDoc} */ + public void setColumnMatrix(final int column, final RealMatrix matrix) + throws MatrixIndexException, InvalidMatrixException { + try { + setColumnMatrix(column, (RecursiveLayoutRealMatrix) matrix); + } catch (ClassCastException cce) { + super.setColumnMatrix(column, matrix); + } + } + + /** + * Sets the entries in column number column + * as a column matrix. Column indices start at 0. + * + * @param column the column to be set + * @param matrix column matrix (must have one column and the same number of rows + * as the instance) + * @throws MatrixIndexException if the specified column index is invalid + * @throws InvalidMatrixException if the matrix dimensions do not match one + * instance column + */ + void setColumnMatrix(final int column, final RecursiveLayoutRealMatrix matrix) + throws MatrixIndexException, InvalidMatrixException { + + checkColumnIndex(column); + final int nRows = getRowDimension(); + if ((matrix.getRowDimension() != nRows) || + (matrix.getColumnDimension() != 1)) { + throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}", + new Object[] { + matrix.getRowDimension(), + matrix.getColumnDimension(), + nRows, 1 + }); + } + + // a column matrix has always only one large tile, + // because a single column cannot be split into 2^k tiles + // perform copy tile-wise, to ensure good cache behavior + final int jTile = column / tileSizeColumns; + final int columnOffset = column - jTile * tileSizeColumns; + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + columnOffset; + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + data[k] = matrix.data[p]; + } + } + + } + + /** {@inheritDoc} */ + public void setRowVector(final int row, final RealVector vector) + throws MatrixIndexException, InvalidMatrixException { + try { + setRow(row, ((RealVectorImpl) vector).getDataRef()); + } catch (ClassCastException cce) { + checkRowIndex(row); + if (vector.getDimension() != columns) { + throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}", + new Object[] { + 1, vector.getDimension(), + 1, columns + }); + } + + // perform copy tile-wise, to ensure good cache behavior + final int iTile = row / tileSizeRows; + final int rowOffset = row - iTile * tileSizeRows; + int outIndex = 0; + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + rowOffset * tileSizeColumns; + final int length = Math.min(outIndex + tileSizeColumns, columns) - outIndex; + for (int l = 0; l < length; ++l) { + data[kStart + l] = vector.getEntry(outIndex + l); + } + outIndex += length; + } + } + } + + /** {@inheritDoc} */ + public void setColumnVector(final int column, final RealVector vector) + throws MatrixIndexException, InvalidMatrixException { + try { + setColumn(column, ((RealVectorImpl) vector).getDataRef()); + } catch (ClassCastException cce) { + checkColumnIndex(column); + if (vector.getDimension() != rows) { + throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}", + new Object[] { + vector.getDimension(), 1, + rows, 1 + }); + } + + // perform copy tile-wise, to ensure good cache behavior + final int jTile = column / tileSizeColumns; + final int columnOffset = column - jTile * tileSizeColumns; + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + columnOffset; + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + data[k] = vector.getEntry(p); + } + } + } + } + + /** {@inheritDoc} */ + public double[] getRow(final int row) + throws MatrixIndexException { + + checkRowIndex(row); + final double[] out = new double[columns]; + + // perform copy tile-wise, to ensure good cache behavior + final int iTile = row / tileSizeRows; + final int rowOffset = row - iTile * tileSizeRows; + int outIndex = 0; + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + rowOffset * tileSizeColumns; + final int length = Math.min(outIndex + tileSizeColumns, columns) - outIndex; + System.arraycopy(data, kStart, out, outIndex, length); + outIndex += length; + } + + return out; + + } + + /** {@inheritDoc} */ + public void setRow(final int row, final double[] array) + throws MatrixIndexException, InvalidMatrixException { + + checkRowIndex(row); + if (array.length != columns) { + throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}", + new Object[] { + 1, array.length, + 1, columns + }); + } + + // perform copy tile-wise, to ensure good cache behavior + final int iTile = row / tileSizeRows; + final int rowOffset = row - iTile * tileSizeRows; + int outIndex = 0; + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + rowOffset * tileSizeColumns; + final int length = Math.min(outIndex + tileSizeColumns, columns) - outIndex; + System.arraycopy(array, outIndex, data, kStart, length); + outIndex += length; + } + + } + + /** {@inheritDoc} */ + public double[] getColumn(final int column) + throws MatrixIndexException { + + checkColumnIndex(column); + final double[] out = new double[rows]; + + // perform copy tile-wise, to ensure good cache behavior + final int jTile = column / tileSizeColumns; + final int columnOffset = column - jTile * tileSizeColumns; + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + columnOffset; + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + out[p] = data[k]; + } + } + + return out; + + } + + /** {@inheritDoc} */ + public void setColumn(final int column, final double[] array) + throws MatrixIndexException, InvalidMatrixException { + + checkColumnIndex(column); + if (array.length != rows) { + throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}", + new Object[] { + array.length, 1, + rows, 1 + }); + } + + // perform copy tile-wise, to ensure good cache behavior + final int jTile = column / tileSizeColumns; + final int columnOffset = column - jTile * tileSizeColumns; + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int kStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns + + columnOffset; + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + data[k] = array[p]; + } + } + + } + + /** {@inheritDoc} */ + public double getEntry(final int row, final int column) + throws MatrixIndexException { + if ((row < 0) || (row >= rows) || (column < 0) || (column >= columns)) { + throw new MatrixIndexException("no entry at indices ({0}, {1}) in a {2}x{3} matrix", + new Object[] { + row, column, + getRowDimension(), getColumnDimension() + }); + } + return data[index(row, column)]; + } + + /** {@inheritDoc} */ + public void setEntry(final int row, final int column, final double value) + throws MatrixIndexException { + if ((row < 0) || (row >= rows) || (column < 0) || (column >= columns)) { + throw new MatrixIndexException("no entry at indices ({0}, {1}) in a {2}x{3} matrix", + new Object[] { + row, column, + getRowDimension(), getColumnDimension() + }); + } + data[index(row, column)] = value; + } + + /** {@inheritDoc} */ + public void addToEntry(final int row, final int column, final double increment) + throws MatrixIndexException { + if ((row < 0) || (row >= rows) || (column < 0) || (column >= columns)) { + throw new MatrixIndexException("no entry at indices ({0}, {1}) in a {2}x{3} matrix", + new Object[] { + row, column, + getRowDimension(), getColumnDimension() + }); + } + data[index(row, column)] += increment; + } + + /** {@inheritDoc} */ + public void multiplyEntry(final int row, final int column, final double factor) + throws MatrixIndexException { + if ((row < 0) || (row >= rows) || (column < 0) || (column >= columns)) { + throw new MatrixIndexException("no entry at indices ({0}, {1}) in a {2}x{3} matrix", + new Object[] { + row, column, + getRowDimension(), getColumnDimension() + }); + } + data[index(row, column)] *= factor; + } + + /** {@inheritDoc} */ + public RealMatrix transpose() { + + final RecursiveLayoutRealMatrix out = new RecursiveLayoutRealMatrix(columns, rows); + + // perform transpose tile-wise, to ensure good cache behavior + for (int index = 0; index < tileNumber * tileNumber; ++index) { + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int outJTile = (int) (indices >> 32); // iTile in the instance + final int outITile = (int) (indices & 0xffffffff); // jTile in the instance + final int outIndex = tileIndex(outITile, outJTile); + final int outTileStart = outIndex * tileSizeRows * tileSizeColumns; + + // transpose current tile + final int outPStart = outITile * tileSizeColumns; + final int outPEnd = Math.min(outPStart + tileSizeColumns, columns); + final int outQStart = outJTile * tileSizeRows; + final int outQEnd = Math.min(outQStart + tileSizeRows, rows); + for (int outP = outPStart; outP < outPEnd; ++outP) { + final int dP = outP - outPStart; + int k = outTileStart + dP * tileSizeRows; + int l = tileStart + dP; + for (int outQ = outQStart; outQ < outQEnd; ++outQ) { + out.data[k++] = data[l]; + l+= tileSizeColumns; + } + } + + } + + return out; + + } + + /** {@inheritDoc} */ + public int getRowDimension() { + return rows; + } + + /** {@inheritDoc} */ + public int getColumnDimension() { + return columns; + } + + /** {@inheritDoc} */ + public double[] operate(final double[] v) + throws IllegalArgumentException { + + if (v.length != columns) { + throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" + + " got {0} but expected {1}", + new Object[] { + v.length, columns + }); + } + final double[] out = new double[rows]; + + // perform multiplication tile-wise, to ensure good cache behavior + for (int index = 0; index < tileNumber * tileNumber; ++index) { + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int p = pStart, k = tileStart; p < pEnd; ++p) { + double sum = 0; + int q = qStart; + while (q < qEnd - 3) { + sum += data[k] * v[q] + + data[k + 1] * v[q + 1] + + data[k + 2] * v[q + 2] + + data[k + 3] * v[q + 3]; + k += 4; + q += 4; + } + while (q < qEnd) { + sum += data[k++] * v[q++]; + } + out[p] += sum; + } + } + + return out; + + } + + /** {@inheritDoc} */ + public double[] preMultiply(final double[] v) + throws IllegalArgumentException { + + if (v.length != rows) { + throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" + + " got {0} but expected {1}", + new Object[] { + v.length, rows + }); + } + final double[] out = new double[columns]; + + final int offset1 = tileSizeColumns; + final int offset2 = offset1 + offset1; + final int offset3 = offset2 + offset1; + final int offset4 = offset3 + offset1; + + // perform multiplication tile-wise, to ensure good cache behavior + for (int index = 0; index < tileNumber * tileNumber; ++index) { + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int q = qStart; q < qEnd; ++q) { + int k = tileStart + q - qStart; + double sum = 0; + int p = pStart; + while (p < pEnd - 3) { + sum += data[k] * v[p] + + data[k + offset1] * v[p + 1] + + data[k + offset2] * v[p + 2] + + data[k + offset3] * v[p + 3]; + k += offset4; + p += 4; + } + while (p < pEnd) { + sum += data[k] * v[p++]; + k += offset1; + } + out[q] += sum; + } + } + + return out; + + } + + /** {@inheritDoc} */ + public double walkInRowOrder(final RealMatrixChangingVisitor visitor) + throws MatrixVisitorException { + visitor.start(rows, columns, 0, rows - 1, 0, columns - 1); + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + for (int p = pStart; p < pEnd; ++p) { + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (p - pStart) * tileSizeColumns; + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + data[k] = visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInRowOrder(final RealMatrixPreservingVisitor visitor) + throws MatrixVisitorException { + visitor.start(rows, columns, 0, rows - 1, 0, columns - 1); + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + for (int p = pStart; p < pEnd; ++p) { + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (p - pStart) * tileSizeColumns; + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInRowOrder(final RealMatrixChangingVisitor visitor, + final int startRow, final int endRow, + final int startColumn, final int endColumn) + throws MatrixIndexException, MatrixVisitorException { + checkSubMatrixIndex(startRow, endRow, startColumn, endColumn); + visitor.start(rows, columns, startRow, endRow, startColumn, endColumn); + for (int iTile = startRow / tileSizeRows; iTile < 1 + endRow / tileSizeRows; ++iTile) { + final int p0 = iTile * tileSizeRows; + final int pStart = Math.max(startRow, p0); + final int pEnd = Math.min((iTile + 1) * tileSizeRows, 1 + endRow); + for (int p = pStart; p < pEnd; ++p) { + for (int jTile = startColumn / tileSizeColumns; jTile < 1 + endColumn / tileSizeColumns; ++jTile) { + final int q0 = jTile * tileSizeColumns; + final int qStart = Math.max(startColumn, q0); + final int qEnd = Math.min((jTile + 1) * tileSizeColumns, 1 + endColumn); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (p - p0) * tileSizeColumns + (qStart - q0); + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + data[k] = visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInRowOrder(final RealMatrixPreservingVisitor visitor, + final int startRow, final int endRow, + final int startColumn, final int endColumn) + throws MatrixIndexException, MatrixVisitorException { + checkSubMatrixIndex(startRow, endRow, startColumn, endColumn); + visitor.start(rows, columns, startRow, endRow, startColumn, endColumn); + for (int iTile = startRow / tileSizeRows; iTile < 1 + endRow / tileSizeRows; ++iTile) { + final int p0 = iTile * tileSizeRows; + final int pStart = Math.max(startRow, p0); + final int pEnd = Math.min((iTile + 1) * tileSizeRows, 1 + endRow); + for (int p = pStart; p < pEnd; ++p) { + for (int jTile = startColumn / tileSizeColumns; jTile < 1 + endColumn / tileSizeColumns; ++jTile) { + final int q0 = jTile * tileSizeColumns; + final int qStart = Math.max(startColumn, q0); + final int qEnd = Math.min((jTile + 1) * tileSizeColumns, 1 + endColumn); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (p - p0) * tileSizeColumns + (qStart - q0); + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInColumnOrder(final RealMatrixChangingVisitor visitor) + throws MatrixVisitorException { + visitor.start(rows, columns, 0, rows - 1, 0, columns - 1); + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int q = qStart; q < qEnd; ++q) { + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (q - qStart); + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + data[k] = visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInColumnOrder(final RealMatrixPreservingVisitor visitor) + throws MatrixVisitorException { + visitor.start(rows, columns, 0, rows - 1, 0, columns - 1); + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int q = qStart; q < qEnd; ++q) { + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (q - qStart); + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInColumnOrder(final RealMatrixChangingVisitor visitor, + final int startRow, final int endRow, + final int startColumn, final int endColumn) + throws MatrixIndexException, MatrixVisitorException { + checkSubMatrixIndex(startRow, endRow, startColumn, endColumn); + visitor.start(getRowDimension(), getColumnDimension(), + startRow, endRow, startColumn, endColumn); + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int q0 = jTile * tileSizeColumns; + final int qStart = Math.max(startColumn, q0); + final int qEnd = Math.min((jTile + 1) * tileSizeColumns, 1 + endColumn); + for (int q = qStart; q < qEnd; ++q) { + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int p0 = iTile * tileSizeRows; + final int pStart = Math.max(startRow, p0); + final int pEnd = Math.min((iTile + 1) * tileSizeRows, 1 + endRow); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (pStart - p0) * tileSizeColumns + (q - q0); + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + data[k] = visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInColumnOrder(final RealMatrixPreservingVisitor visitor, + final int startRow, final int endRow, + final int startColumn, final int endColumn) + throws MatrixIndexException, MatrixVisitorException { + checkSubMatrixIndex(startRow, endRow, startColumn, endColumn); + visitor.start(getRowDimension(), getColumnDimension(), + startRow, endRow, startColumn, endColumn); + for (int jTile = 0; jTile < tileNumber; ++jTile) { + final int q0 = jTile * tileSizeColumns; + final int qStart = Math.max(startColumn, q0); + final int qEnd = Math.min((jTile + 1) * tileSizeColumns, 1 + endColumn); + for (int q = qStart; q < qEnd; ++q) { + for (int iTile = 0; iTile < tileNumber; ++iTile) { + final int p0 = iTile * tileSizeRows; + final int pStart = Math.max(startRow, p0); + final int pEnd = Math.min((iTile + 1) * tileSizeRows, 1 + endRow); + final int tileStart = tileIndex(iTile, jTile) * + tileSizeRows * tileSizeColumns; + final int kStart = tileStart + (pStart - p0) * tileSizeColumns + (q - q0); + for (int p = pStart, k = kStart; p < pEnd; ++p, k += tileSizeColumns) { + visitor.visit(p, q, data[k]); + } + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInOptimizedOrder(final RealMatrixChangingVisitor visitor) + throws MatrixVisitorException { + visitor.start(rows, columns, 0, rows - 1, 0, columns - 1); + for (int index = 0; index < tileNumber * tileNumber; ++index) { + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int p = pStart; p < pEnd; ++p) { + final int kStart = tileStart + (p - pStart) * tileSizeColumns; + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + data[k] = visitor.visit(p, q, data[k]); + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInOptimizedOrder(final RealMatrixPreservingVisitor visitor) + throws MatrixVisitorException { + visitor.start(rows, columns, 0, rows - 1, 0, columns - 1); + for (int index = 0; index < tileNumber * tileNumber; ++index) { + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int pStart = iTile * tileSizeRows; + final int pEnd = Math.min(pStart + tileSizeRows, rows); + final int qStart = jTile * tileSizeColumns; + final int qEnd = Math.min(qStart + tileSizeColumns, columns); + for (int p = pStart; p < pEnd; ++p) { + final int kStart = tileStart + (p - pStart) * tileSizeColumns; + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + visitor.visit(p, q, data[k]); + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInOptimizedOrder(final RealMatrixChangingVisitor visitor, + final int startRow, final int endRow, + final int startColumn, final int endColumn) + throws MatrixIndexException, MatrixVisitorException { + checkSubMatrixIndex(startRow, endRow, startColumn, endColumn); + visitor.start(rows, columns, startRow, endRow, startColumn, endColumn); + for (int index = 0; index < tileNumber * tileNumber; ++index) { + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int p0 = iTile * tileSizeRows; + final int pStart = Math.max(startRow, p0); + final int pEnd = Math.min((iTile + 1) * tileSizeRows, 1 + endRow); + final int q0 = jTile * tileSizeColumns; + final int qStart = Math.max(startColumn, q0); + final int qEnd = Math.min((jTile + 1) * tileSizeColumns, 1 + endColumn); + for (int p = pStart; p < pEnd; ++p) { + final int kStart = tileStart + (p - p0) * tileSizeColumns + (qStart - q0); + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + data[k] = visitor.visit(p, q, data[k]); + } + } + } + return visitor.end(); + } + + /** {@inheritDoc} */ + public double walkInOptimizedOrder(final RealMatrixPreservingVisitor visitor, + final int startRow, final int endRow, + final int startColumn, final int endColumn) + throws MatrixIndexException, MatrixVisitorException { + checkSubMatrixIndex(startRow, endRow, startColumn, endColumn); + visitor.start(rows, columns, startRow, endRow, startColumn, endColumn); + for (int index = 0; index < tileNumber * tileNumber; ++index) { + final int tileStart = index * tileSizeRows * tileSizeColumns; + final long indices = tilesIndices(index); + final int iTile = (int) (indices >> 32); + final int jTile = (int) (indices & 0xffffffff); + final int p0 = iTile * tileSizeRows; + final int pStart = Math.max(startRow, p0); + final int pEnd = Math.min((iTile + 1) * tileSizeRows, 1 + endRow); + final int q0 = jTile * tileSizeColumns; + final int qStart = Math.max(startColumn, q0); + final int qEnd = Math.min((jTile + 1) * tileSizeColumns, 1 + endColumn); + for (int p = pStart; p < pEnd; ++p) { + final int kStart = tileStart + (p - p0) * tileSizeColumns + (qStart - q0); + for (int q = qStart, k = kStart; q < qEnd; ++q, ++k) { + visitor.visit(p, q, data[k]); + } + } + } + return visitor.end(); + } + + /** + * Get the index of an element. + * @param row row index of the element + * @param column column index of the element + * @return index of the element + */ + private int index(final int row, final int columns) { + final int iTile = row / tileSizeRows; + final int jTile = columns / tileSizeColumns; + final int tileStart = tileIndex(iTile, jTile) * tileSizeRows * tileSizeColumns; + final int indexInTile = (row % tileSizeRows) * tileSizeColumns + + (columns % tileSizeColumns); + return tileStart + indexInTile; + } + + /** + * Get the index of a tile. + * @param iTile row index of the tile + * @param jTile column index of the tile + * @return index of the tile + */ + private static int tileIndex(int iTile, int jTile) { + + // compute n = 2^k such that a nxn square contains the indices + int n = Integer.highestOneBit(Math.max(iTile, jTile)) << 1; + + // start recursion by noting the index is somewhere in the nxn + // square whose lowest index is 0 and which has direct orientation + int lowIndex = 0; + boolean direct = true; + + // the tail-recursion on the square size is replaced by an iteration here + while (n > 1) { + + // reduce square to 4 quadrants + n >>= 1; + final int n2 = n * n; + + // check in which quadrant the element is, + // updating the lowest index of the quadrant and its orientation + if (iTile < n) { + if (jTile < n) { + // the element is in the top-left quadrant + if (!direct) { + lowIndex += 2 * n2; + direct = true; + } + } else { + // the element is in the top-right quadrant + jTile -= n; + if (direct) { + lowIndex += n2; + direct = false; + } else { + lowIndex += 3 * n2; + } + } + } else { + iTile -= n; + if (jTile < n) { + // the element is in the bottom-left quadrant + if (direct) { + lowIndex += 3 * n2; + } else { + lowIndex += n2; + direct = true; + } + } else { + // the element is in the bottom-right quadrant + jTile -= n; + if (direct) { + lowIndex += 2 * n2; + direct = false; + } + } + } + } + + // the lowest index of the remaining 1x1 quadrant is the requested index + return lowIndex; + + } + + /** + * Get the row and column tile indices of a tile. + * @param index index of the tile in the layout + * @return row and column indices packed in one long (row tile index + * in 32 high order bits, column tile index in low order bits) + */ + private static long tilesIndices(int index) { + + // compute n = 2^k such that a nxn square contains the index + int n = Integer.highestOneBit((int) Math.sqrt(index)) << 1; + + // start recursion by noting the index is somewhere in the nxn + // square whose lowest index is 0 and which has direct orientation + int iLow = 0; + int jLow = 0; + boolean direct = true; + + // the tail-recursion on the square size is replaced by an iteration here + while (n > 1) { + + // reduce square to 4 quadrants + n >>= 1; + final int n2 = n * n; + + // check in which quadrant the element is, + // updating the low indices of the quadrant and its orientation + switch (index / n2) { + case 0 : + if (!direct) { + iLow += n; + jLow += n; + } + break; + case 1 : + if (direct) { + jLow += n; + } else { + iLow += n; + } + index -= n2; + direct = !direct; + break; + case 2 : + if (direct) { + iLow += n; + jLow += n; + } + index -= 2 * n2; + direct = !direct; + break; + default : + if (direct) { + iLow += n; + } else { + jLow += n; + } + index -= 3 * n2; + } + + } + + // the lowest indices of the remaining 1x1 quadrant are the requested indices + return (((long) iLow) << 32) | (long) jLow; + + } + + /** + * Compute the power of two number of tiles for a matrix. + * @param rows number of rows + * @param columns number of columns + * @return power of two number of tiles + */ + private static int tilesNumber(final int rows, final int columns) { + + // find the minimal number of tiles, given that one double variable is 8 bytes + final int nbElements = rows * columns; + final int maxElementsPerTile = MAX_TILE_SIZE_BYTES / 8; + final int minTiles = nbElements / maxElementsPerTile; + + // the number of tiles must be a 2^k x 2^k square + int twoK = 1; + for (int nTiles = minTiles; nTiles != 0; nTiles >>= 2) { + twoK <<= 1; + } + + // make sure the tiles have at least one row and one column each + // (this may lead to tile sizes greater than MAX_BLOCK_SIZE_BYTES, + // in degenerate cases like a 3000x1 matrix) + while (twoK > Math.min(rows, columns)) { + twoK >>= 1; + } + + return twoK; + + } + + /** + * Compute optimal tile size for a row or column count. + * @param count row or column count + * @param twoK optimal tile number (must be a power of 2) + * @return optimal tile size + */ + private static int tileSize(final int count, final int twoK) { + return (count + twoK - 1) / twoK; + } + +} diff --git a/src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrixTest.java b/src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrixTest.java new file mode 100644 index 000000000..2e8719097 --- /dev/null +++ b/src/experimental/org/apache/commons/math/linear/RecursiveLayoutRealMatrixTest.java @@ -0,0 +1,1242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math.linear; + +import java.util.Arrays; +import java.util.Random; + +import junit.framework.Test; +import junit.framework.TestCase; +import junit.framework.TestSuite; + +/** + * Test cases for the {@link RecursiveLayoutRealMatrix} class. + * + * @version $Revision$ $Date$ + */ + +public final class RecursiveLayoutRealMatrixTest extends TestCase { + + // 3 x 3 identity matrix + protected double[][] id = { {1d,0d,0d}, {0d,1d,0d}, {0d,0d,1d} }; + + // Test data for group operations + protected double[][] testData = { {1d,2d,3d}, {2d,5d,3d}, {1d,0d,8d} }; + protected double[][] testDataLU = {{2d, 5d, 3d}, {.5d, -2.5d, 6.5d}, {0.5d, 0.2d, .2d}}; + protected double[][] testDataPlus2 = { {3d,4d,5d}, {4d,7d,5d}, {3d,2d,10d} }; + protected double[][] testDataMinus = { {-1d,-2d,-3d}, {-2d,-5d,-3d}, + {-1d,0d,-8d} }; + protected double[] testDataRow1 = {1d,2d,3d}; + protected double[] testDataCol3 = {3d,3d,8d}; + protected double[][] testDataInv = + { {-40d,16d,9d}, {13d,-5d,-3d}, {5d,-2d,-1d} }; + protected double[] preMultTest = {8,12,33}; + protected double[][] testData2 ={ {1d,2d,3d}, {2d,5d,3d}}; + protected double[][] testData2T = { {1d,2d}, {2d,5d}, {3d,3d}}; + protected double[][] testDataPlusInv = + { {-39d,18d,12d}, {15d,0d,0d}, {6d,-2d,7d} }; + + // lu decomposition tests + protected double[][] luData = { {2d,3d,3d}, {0d,5d,7d}, {6d,9d,8d} }; + protected double[][] luDataLUDecomposition = { {6d,9d,8d}, {0d,5d,7d}, + {0.33333333333333,0d,0.33333333333333} }; + + // singular matrices + protected double[][] singular = { {2d,3d}, {2d,3d} }; + protected double[][] bigSingular = {{1d,2d,3d,4d}, {2d,5d,3d,4d}, + {7d,3d,256d,1930d}, {3d,7d,6d,8d}}; // 4th row = 1st + 2nd + protected double[][] detData = { {1d,2d,3d}, {4d,5d,6d}, {7d,8d,10d} }; + protected double[][] detData2 = { {1d, 3d}, {2d, 4d}}; + + // vectors + protected double[] testVector = {1,2,3}; + protected double[] testVector2 = {1,2,3,4}; + + // submatrix accessor tests + protected double[][] subTestData = {{1, 2, 3, 4}, {1.5, 2.5, 3.5, 4.5}, + {2, 4, 6, 8}, {4, 5, 6, 7}}; + // array selections + protected double[][] subRows02Cols13 = { {2, 4}, {4, 8}}; + protected double[][] subRows03Cols12 = { {2, 3}, {5, 6}}; + protected double[][] subRows03Cols123 = { {2, 3, 4} , {5, 6, 7}}; + // effective permutations + protected double[][] subRows20Cols123 = { {4, 6, 8} , {2, 3, 4}}; + protected double[][] subRows31Cols31 = {{7, 5}, {4.5, 2.5}}; + // contiguous ranges + protected double[][] subRows01Cols23 = {{3,4} , {3.5, 4.5}}; + protected double[][] subRows23Cols00 = {{2} , {4}}; + protected double[][] subRows00Cols33 = {{4}}; + // row matrices + protected double[][] subRow0 = {{1,2,3,4}}; + protected double[][] subRow3 = {{4,5,6,7}}; + // column matrices + protected double[][] subColumn1 = {{2}, {2.5}, {4}, {5}}; + protected double[][] subColumn3 = {{4}, {4.5}, {8}, {7}}; + + // tolerances + protected double entryTolerance = 10E-16; + protected double normTolerance = 10E-14; + + public RecursiveLayoutRealMatrixTest(String name) { + super(name); + } + + public void setUp() { + + } + + public static Test suite() { + TestSuite suite = new TestSuite(RecursiveLayoutRealMatrixTest.class); + suite.setName("RecursiveLayoutRealMatrix Tests"); + return suite; + } + + /** test dimensions */ + public void testDimensions() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix m2 = new RecursiveLayoutRealMatrix(testData2); + assertEquals("testData row dimension",3,m.getRowDimension()); + assertEquals("testData column dimension",3,m.getColumnDimension()); + assertTrue("testData is square",m.isSquare()); + assertEquals("testData2 row dimension",m2.getRowDimension(),2); + assertEquals("testData2 column dimension",m2.getColumnDimension(),3); + assertTrue("testData2 is not square",!m2.isSquare()); + } + + /** test copy functions */ + public void testCopyFunctions() { + Random r = new Random(66636328996002l); + RecursiveLayoutRealMatrix m1 = createRandomMatrix(r, 47, 83); + RecursiveLayoutRealMatrix m2 = new RecursiveLayoutRealMatrix(m1.getData()); + assertEquals(m1, m2); + RecursiveLayoutRealMatrix m3 = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix m4 = new RecursiveLayoutRealMatrix(m3.getData()); + assertEquals(m3, m4); + } + + /** test add */ + public void testAdd() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix mInv = new RecursiveLayoutRealMatrix(testDataInv); + RealMatrix mPlusMInv = m.add(mInv); + double[][] sumEntries = mPlusMInv.getData(); + for (int row = 0; row < m.getRowDimension(); row++) { + for (int col = 0; col < m.getColumnDimension(); col++) { + assertEquals("sum entry entry", + testDataPlusInv[row][col],sumEntries[row][col], + entryTolerance); + } + } + } + + /** test add failure */ + public void testAddFail() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix m2 = new RecursiveLayoutRealMatrix(testData2); + try { + m.add(m2); + fail("IllegalArgumentException expected"); + } catch (IllegalArgumentException ex) { + ; + } + } + + /** test norm */ + public void testNorm() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix m2 = new RecursiveLayoutRealMatrix(testData2); + assertEquals("testData norm",14d,m.getNorm(),entryTolerance); + assertEquals("testData2 norm",7d,m2.getNorm(),entryTolerance); + } + + /** test Frobenius norm */ + public void testFrobeniusNorm() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix m2 = new RecursiveLayoutRealMatrix(testData2); + assertEquals("testData Frobenius norm", Math.sqrt(117.0), m.getFrobeniusNorm(), entryTolerance); + assertEquals("testData2 Frobenius norm", Math.sqrt(52.0), m2.getFrobeniusNorm(), entryTolerance); + } + + /** test m-n = m + -n */ + public void testPlusMinus() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix m2 = new RecursiveLayoutRealMatrix(testDataInv); + assertClose(m.subtract(m2), m2.scalarMultiply(-1d).add(m), entryTolerance); + try { + m.subtract(new RecursiveLayoutRealMatrix(testData2)); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + /** test multiply */ + public void testMultiply() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix mInv = new RecursiveLayoutRealMatrix(testDataInv); + RecursiveLayoutRealMatrix identity = new RecursiveLayoutRealMatrix(id); + RecursiveLayoutRealMatrix m2 = new RecursiveLayoutRealMatrix(testData2); + assertClose(m.multiply(mInv), identity, entryTolerance); + assertClose(mInv.multiply(m), identity, entryTolerance); + assertClose(m.multiply(identity), m, entryTolerance); + assertClose(identity.multiply(mInv), mInv, entryTolerance); + assertClose(m2.multiply(identity), m2, entryTolerance); + try { + m.multiply(new RecursiveLayoutRealMatrix(bigSingular)); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + // expected + } + } + + public void testSeveralBlocks() { + + RealMatrix m = new RecursiveLayoutRealMatrix(35, 71); + for (int i = 0; i < m.getRowDimension(); ++i) { + for (int j = 0; j < m.getColumnDimension(); ++j) { + m.setEntry(i, j, i + j / 1024.0); + } + } + + RealMatrix mT = m.transpose(); + assertEquals(m.getRowDimension(), mT.getColumnDimension()); + assertEquals(m.getColumnDimension(), mT.getRowDimension()); + for (int i = 0; i < mT.getRowDimension(); ++i) { + for (int j = 0; j < mT.getColumnDimension(); ++j) { + assertEquals(m.getEntry(j, i), mT.getEntry(i, j), 0); + } + } + + RealMatrix mPm = m.add(m); + for (int i = 0; i < mPm.getRowDimension(); ++i) { + for (int j = 0; j < mPm.getColumnDimension(); ++j) { + assertEquals(2 * m.getEntry(i, j), mPm.getEntry(i, j), 0); + } + } + + RealMatrix mPmMm = mPm.subtract(m); + for (int i = 0; i < mPmMm.getRowDimension(); ++i) { + for (int j = 0; j < mPmMm.getColumnDimension(); ++j) { + assertEquals(m.getEntry(i, j), mPmMm.getEntry(i, j), 0); + } + } + + RealMatrix mTm = mT.multiply(m); + for (int i = 0; i < mTm.getRowDimension(); ++i) { + for (int j = 0; j < mTm.getColumnDimension(); ++j) { + double sum = 0; + for (int k = 0; k < mT.getColumnDimension(); ++k) { + sum += (k + i / 1024.0) * (k + j / 1024.0); + } + assertEquals(sum, mTm.getEntry(i, j), 0); + } + } + + RealMatrix mmT = m.multiply(mT); + for (int i = 0; i < mmT.getRowDimension(); ++i) { + for (int j = 0; j < mmT.getColumnDimension(); ++j) { + double sum = 0; + for (int k = 0; k < m.getColumnDimension(); ++k) { + sum += (i + k / 1024.0) * (j + k / 1024.0); + } + assertEquals(sum, mmT.getEntry(i, j), 0); + } + } + + RealMatrix sub1 = m.getSubMatrix(2, 9, 5, 20); + for (int i = 0; i < sub1.getRowDimension(); ++i) { + for (int j = 0; j < sub1.getColumnDimension(); ++j) { + assertEquals((i + 2) + (j + 5) / 1024.0, sub1.getEntry(i, j), 0); + } + } + + RealMatrix sub2 = m.getSubMatrix(10, 12, 3, 70); + for (int i = 0; i < sub2.getRowDimension(); ++i) { + for (int j = 0; j < sub2.getColumnDimension(); ++j) { + assertEquals((i + 10) + (j + 3) / 1024.0, sub2.getEntry(i, j), 0); + } + } + + RealMatrix sub3 = m.getSubMatrix(30, 34, 0, 5); + for (int i = 0; i < sub3.getRowDimension(); ++i) { + for (int j = 0; j < sub3.getColumnDimension(); ++j) { + assertEquals((i + 30) + (j + 0) / 1024.0, sub3.getEntry(i, j), 0); + } + } + + RealMatrix sub4 = m.getSubMatrix(30, 32, 62, 65); + for (int i = 0; i < sub4.getRowDimension(); ++i) { + for (int j = 0; j < sub4.getColumnDimension(); ++j) { + assertEquals((i + 30) + (j + 62) / 1024.0, sub4.getEntry(i, j), 0); + } + } + + } + + //Additional Test for RecursiveLayoutRealMatrixTest.testMultiply + + private double[][] d3 = new double[][] {{1,2,3,4},{5,6,7,8}}; + private double[][] d4 = new double[][] {{1},{2},{3},{4}}; + private double[][] d5 = new double[][] {{30},{70}}; + + public void testMultiply2() { + RealMatrix m3 = new RecursiveLayoutRealMatrix(d3); + RealMatrix m4 = new RecursiveLayoutRealMatrix(d4); + RealMatrix m5 = new RecursiveLayoutRealMatrix(d5); + assertClose(m3.multiply(m4), m5, entryTolerance); + } + + /** test trace */ + public void testTrace() { + RealMatrix m = new RecursiveLayoutRealMatrix(id); + assertEquals("identity trace",3d,m.getTrace(),entryTolerance); + m = new RecursiveLayoutRealMatrix(testData2); + try { + m.getTrace(); + fail("Expecting NonSquareMatrixException"); + } catch (NonSquareMatrixException ex) { + ; + } + } + + /** test scalarAdd */ + public void testScalarAdd() { + RealMatrix m = new RecursiveLayoutRealMatrix(testData); + assertClose(new RecursiveLayoutRealMatrix(testDataPlus2), m.scalarAdd(2d), entryTolerance); + } + + /** test operate */ + public void testOperate() { + RealMatrix m = new RecursiveLayoutRealMatrix(id); + assertClose(testVector, m.operate(testVector), entryTolerance); + assertClose(testVector, m.operate(new RealVectorImpl(testVector)).getData(), entryTolerance); + m = new RecursiveLayoutRealMatrix(bigSingular); + try { + m.operate(testVector); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + public void testMultiplyMedium() { + RealMatrix m1 = new RecursiveLayoutRealMatrix( + new double[][] { + { 80, 45, 13, 77, -82 }, + { -90, 33, 98, 80, 74 }, + { 24, -37, 36, -8, -69 }, + { -74, 2, 32, -67, -65 }, + { -29, -81, 44, 54, -65 }, + { 17, 58, -36, -98, 25 }, + { 48, -64, -95, -75, 34 } + }); + RealMatrix m2 = new RecursiveLayoutRealMatrix( + new double[][] { + { 81, 58, 70, 18, 5, -57 }, + { -54, 33, 87, 68, -22, 73 }, + { -78, -5, 34, -7, -3, -31 }, + { -16, -82, -68, 7, 10, -47 }, + { 51, 4, 92, 15, 32, -51 } + }); + RealMatrix m1m2 = m1.multiply(m2); + RealMatrix reference = new RecursiveLayoutRealMatrix( + new double[][]{ + { -2378, -582, -2823, 3718, -2483, -1115 }, + { -14222, -10885, 1271, 1608, 1698, -3033 }, + { -2257, 371, -6119, -3427, -1462, -1290 }, + { -10841, 848, -5342, -2864, -3260, 9836 }, + { -5586, -9263, -17233, -6935, -35, -4847 }, + { 3896, 11216, 13976, 4191, -1263, 7712 }, + { 17688, 7433, 2790, -2838, 2271, -2672 } + }); + assertEquals(0, m1m2.subtract(reference).getNorm(), 0.0); + } + + public void testOperateLarge() { + int testBlockSize = 64; + int p = (7 * testBlockSize) / 2; + int q = (5 * testBlockSize) / 2; + int r = 3 * testBlockSize; + Random random = new Random(111007463902334l); + RealMatrix m1 = createRandomMatrix(random, p, q); + RealMatrix m2 = createRandomMatrix(random, q, r); + RealMatrix m1m2 = m1.multiply(m2); + for (int i = 0; i < r; ++i) { + checkArrays(m1m2.getColumn(i), m1.operate(m2.getColumn(i))); + } + } + + public void testOperatePremultiplyLarge() { + int testBlockSize = 64; + int p = (7 * testBlockSize) / 2; + int q = (5 * testBlockSize) / 2; + int r = 3 * testBlockSize; + Random random = new Random(111007463902334l); + RealMatrix m1 = createRandomMatrix(random, p, q); + RealMatrix m2 = createRandomMatrix(random, q, r); + RealMatrix m1m2 = m1.multiply(m2); + for (int i = 0; i < p; ++i) { + checkArrays(m1m2.getRow(i), m2.preMultiply(m1.getRow(i))); + } + } + + /** test issue MATH-209 */ + public void testMath209() { + RealMatrix a = new RecursiveLayoutRealMatrix(new double[][] { + { 1, 2 }, { 3, 4 }, { 5, 6 } + }); + double[] b = a.operate(new double[] { 1, 1 }); + assertEquals(a.getRowDimension(), b.length); + assertEquals( 3.0, b[0], 1.0e-12); + assertEquals( 7.0, b[1], 1.0e-12); + assertEquals(11.0, b[2], 1.0e-12); + } + + /** test transpose */ + public void testTranspose() { + RealMatrix m = new RecursiveLayoutRealMatrix(testData); + RealMatrix mIT = new LUDecompositionImpl(m).getSolver().getInverse().transpose(); + RealMatrix mTI = new LUDecompositionImpl(m.transpose()).getSolver().getInverse(); + assertClose(mIT, mTI, normTolerance); + m = new RecursiveLayoutRealMatrix(testData2); + RealMatrix mt = new RecursiveLayoutRealMatrix(testData2T); + assertClose(mt, m.transpose(), normTolerance); + } + + /** test preMultiply by vector */ + public void testPremultiplyVector() { + RealMatrix m = new RecursiveLayoutRealMatrix(testData); + assertClose(m.preMultiply(testVector), preMultTest, normTolerance); + assertClose(m.preMultiply(new RealVectorImpl(testVector).getData()), + preMultTest, normTolerance); + m = new RecursiveLayoutRealMatrix(bigSingular); + try { + m.preMultiply(testVector); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + public void testPremultiply() { + RealMatrix m3 = new RecursiveLayoutRealMatrix(d3); + RealMatrix m4 = new RecursiveLayoutRealMatrix(d4); + RealMatrix m5 = new RecursiveLayoutRealMatrix(d5); + assertClose(m4.preMultiply(m3), m5, entryTolerance); + + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix mInv = new RecursiveLayoutRealMatrix(testDataInv); + RecursiveLayoutRealMatrix identity = new RecursiveLayoutRealMatrix(id); + assertClose(m.preMultiply(mInv), identity, entryTolerance); + assertClose(mInv.preMultiply(m), identity, entryTolerance); + assertClose(m.preMultiply(identity), m, entryTolerance); + assertClose(identity.preMultiply(mInv), mInv, entryTolerance); + try { + m.preMultiply(new RecursiveLayoutRealMatrix(bigSingular)); + fail("Expecting illegalArgumentException"); + } catch (IllegalArgumentException ex) { + ; + } + } + + public void testGetVectors() { + RealMatrix m = new RecursiveLayoutRealMatrix(testData); + assertClose(m.getRow(0), testDataRow1, entryTolerance); + assertClose(m.getColumn(2), testDataCol3, entryTolerance); + try { + m.getRow(10); + fail("expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + ; + } + try { + m.getColumn(-1); + fail("expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + ; + } + } + + public void testGetEntry() { + RealMatrix m = new RecursiveLayoutRealMatrix(testData); + assertEquals("get entry",m.getEntry(0,1),2d,entryTolerance); + try { + m.getEntry(10, 4); + fail ("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + } + + /** test examples in user guide */ + public void testExamples() { + // Create a real matrix with two rows and three columns + double[][] matrixData = { {1d,2d,3d}, {2d,5d,3d}}; + RealMatrix m = new RecursiveLayoutRealMatrix(matrixData); + // One more with three rows, two columns + double[][] matrixData2 = { {1d,2d}, {2d,5d}, {1d, 7d}}; + RealMatrix n = new RecursiveLayoutRealMatrix(matrixData2); + // Now multiply m by n + RealMatrix p = m.multiply(n); + assertEquals(2, p.getRowDimension()); + assertEquals(2, p.getColumnDimension()); + // Invert p + RealMatrix pInverse = new LUDecompositionImpl(p).getSolver().getInverse(); + assertEquals(2, pInverse.getRowDimension()); + assertEquals(2, pInverse.getColumnDimension()); + + // Solve example + double[][] coefficientsData = {{2, 3, -2}, {-1, 7, 6}, {4, -3, -5}}; + RealMatrix coefficients = new RecursiveLayoutRealMatrix(coefficientsData); + double[] constants = {1, -2, 1}; + double[] solution = new LUDecompositionImpl(coefficients).getSolver().solve(constants); + assertEquals(2 * solution[0] + 3 * solution[1] -2 * solution[2], constants[0], 1E-12); + assertEquals(-1 * solution[0] + 7 * solution[1] + 6 * solution[2], constants[1], 1E-12); + assertEquals(4 * solution[0] - 3 * solution[1] -5 * solution[2], constants[2], 1E-12); + + } + + // test submatrix accessors + public void testGetSubMatrix() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + checkGetSubMatrix(m, subRows23Cols00, 2 , 3 , 0, 0, false); + checkGetSubMatrix(m, subRows00Cols33, 0 , 0 , 3, 3, false); + checkGetSubMatrix(m, subRows01Cols23, 0 , 1 , 2, 3, false); + checkGetSubMatrix(m, subRows02Cols13, new int[] { 0, 2 }, new int[] { 1, 3 }, false); + checkGetSubMatrix(m, subRows03Cols12, new int[] { 0, 3 }, new int[] { 1, 2 }, false); + checkGetSubMatrix(m, subRows03Cols123, new int[] { 0, 3 }, new int[] { 1, 2, 3 }, false); + checkGetSubMatrix(m, subRows20Cols123, new int[] { 2, 0 }, new int[] { 1, 2, 3 }, false); + checkGetSubMatrix(m, subRows31Cols31, new int[] { 3, 1 }, new int[] { 3, 1 }, false); + checkGetSubMatrix(m, subRows31Cols31, new int[] { 3, 1 }, new int[] { 3, 1 }, false); + checkGetSubMatrix(m, null, 1, 0, 2, 4, true); + checkGetSubMatrix(m, null, -1, 1, 2, 2, true); + checkGetSubMatrix(m, null, 1, 0, 2, 2, true); + checkGetSubMatrix(m, null, 1, 0, 2, 4, true); + checkGetSubMatrix(m, null, new int[] {}, new int[] { 0 }, true); + checkGetSubMatrix(m, null, new int[] { 0 }, new int[] { 4 }, true); + } + + private void checkGetSubMatrix(RealMatrix m, double[][] reference, + int startRow, int endRow, int startColumn, int endColumn, + boolean mustFail) { + try { + RealMatrix sub = m.getSubMatrix(startRow, endRow, startColumn, endColumn); + assertEquals(new RecursiveLayoutRealMatrix(reference), sub); + if (mustFail) { + fail("Expecting MatrixIndexException"); + } + } catch (MatrixIndexException e) { + if (!mustFail) { + throw e; + } + } + } + + private void checkGetSubMatrix(RealMatrix m, double[][] reference, + int[] selectedRows, int[] selectedColumns, + boolean mustFail) { + try { + RealMatrix sub = m.getSubMatrix(selectedRows, selectedColumns); + assertEquals(new RecursiveLayoutRealMatrix(reference), sub); + if (mustFail) { + fail("Expecting MatrixIndexException"); + } + } catch (MatrixIndexException e) { + if (!mustFail) { + throw e; + } + } + } + + public void testGetSetMatrixLarge() { + int n = 3 * 64; + RealMatrix m = new RecursiveLayoutRealMatrix(n, n); + RealMatrix sub = new RecursiveLayoutRealMatrix(n - 4, n - 4).scalarAdd(1); + + m.setSubMatrix(sub.getData(), 2, 2); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + if ((i < 2) || (i > n - 3) || (j < 2) || (j > n - 3)) { + assertEquals(0.0, m.getEntry(i, j), 0.0); + } else { + assertEquals(1.0, m.getEntry(i, j), 0.0); + } + } + } + assertEquals(sub, m.getSubMatrix(2, n - 3, 2, n - 3)); + + } + + public void testCopySubMatrix() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + checkCopy(m, subRows23Cols00, 2 , 3 , 0, 0, false); + checkCopy(m, subRows00Cols33, 0 , 0 , 3, 3, false); + checkCopy(m, subRows01Cols23, 0 , 1 , 2, 3, false); + checkCopy(m, subRows02Cols13, new int[] { 0, 2 }, new int[] { 1, 3 }, false); + checkCopy(m, subRows03Cols12, new int[] { 0, 3 }, new int[] { 1, 2 }, false); + checkCopy(m, subRows03Cols123, new int[] { 0, 3 }, new int[] { 1, 2, 3 }, false); + checkCopy(m, subRows20Cols123, new int[] { 2, 0 }, new int[] { 1, 2, 3 }, false); + checkCopy(m, subRows31Cols31, new int[] { 3, 1 }, new int[] { 3, 1 }, false); + checkCopy(m, subRows31Cols31, new int[] { 3, 1 }, new int[] { 3, 1 }, false); + + checkCopy(m, null, 1, 0, 2, 4, true); + checkCopy(m, null, -1, 1, 2, 2, true); + checkCopy(m, null, 1, 0, 2, 2, true); + checkCopy(m, null, 1, 0, 2, 4, true); + checkCopy(m, null, new int[] {}, new int[] { 0 }, true); + checkCopy(m, null, new int[] { 0 }, new int[] { 4 }, true); + } + + private void checkCopy(RealMatrix m, double[][] reference, + int startRow, int endRow, int startColumn, int endColumn, + boolean mustFail) { + try { + double[][] sub = (reference == null) ? + new double[1][1] : + new double[reference.length][reference[0].length]; + m.copySubMatrix(startRow, endRow, startColumn, endColumn, sub); + assertEquals(new RecursiveLayoutRealMatrix(reference), new RecursiveLayoutRealMatrix(sub)); + if (mustFail) { + fail("Expecting MatrixIndexException"); + } + } catch (MatrixIndexException e) { + if (!mustFail) { + throw e; + } + } + } + + private void checkCopy(RealMatrix m, double[][] reference, + int[] selectedRows, int[] selectedColumns, + boolean mustFail) { + try { + double[][] sub = (reference == null) ? + new double[1][1] : + new double[reference.length][reference[0].length]; + m.copySubMatrix(selectedRows, selectedColumns, sub); + assertEquals(new RecursiveLayoutRealMatrix(reference), new RecursiveLayoutRealMatrix(sub)); + if (mustFail) { + fail("Expecting MatrixIndexException"); + } + } catch (MatrixIndexException e) { + if (!mustFail) { + throw e; + } + } + } + + public void testGetRowMatrix() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealMatrix mRow0 = new RecursiveLayoutRealMatrix(subRow0); + RealMatrix mRow3 = new RecursiveLayoutRealMatrix(subRow3); + assertEquals("Row0", mRow0, m.getRowMatrix(0)); + assertEquals("Row3", mRow3, m.getRowMatrix(3)); + try { + m.getRowMatrix(-1); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.getRowMatrix(4); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + } + + public void testSetRowMatrix() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealMatrix mRow3 = new RecursiveLayoutRealMatrix(subRow3); + assertNotSame(mRow3, m.getRowMatrix(0)); + m.setRowMatrix(0, mRow3); + assertEquals(mRow3, m.getRowMatrix(0)); + try { + m.setRowMatrix(-1, mRow3); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.setRowMatrix(0, m); + fail("Expecting InvalidMatrixException"); + } catch (InvalidMatrixException ex) { + // expected + } + } + + public void testGetSetRowMatrixLarge() { + int n = 3 * 64; + RealMatrix m = new RecursiveLayoutRealMatrix(n, n); + RealMatrix sub = new RecursiveLayoutRealMatrix(1, n).scalarAdd(1); + + m.setRowMatrix(2, sub); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + if (i != 2) { + assertEquals(0.0, m.getEntry(i, j), 0.0); + } else { + assertEquals(1.0, m.getEntry(i, j), 0.0); + } + } + } + assertEquals(sub, m.getRowMatrix(2)); + + } + + public void testGetColumnMatrix() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealMatrix mColumn1 = new RecursiveLayoutRealMatrix(subColumn1); + RealMatrix mColumn3 = new RecursiveLayoutRealMatrix(subColumn3); + assertEquals(mColumn1, m.getColumnMatrix(1)); + assertEquals(mColumn3, m.getColumnMatrix(3)); + try { + m.getColumnMatrix(-1); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.getColumnMatrix(4); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + } + + public void testSetColumnMatrix() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealMatrix mColumn3 = new RecursiveLayoutRealMatrix(subColumn3); + assertNotSame(mColumn3, m.getColumnMatrix(1)); + m.setColumnMatrix(1, mColumn3); + assertEquals(mColumn3, m.getColumnMatrix(1)); + try { + m.setColumnMatrix(-1, mColumn3); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.setColumnMatrix(0, m); + fail("Expecting InvalidMatrixException"); + } catch (InvalidMatrixException ex) { + // expected + } + } + + public void testGetSetColumnMatrixLarge() { + int n = 3 * 64; + RealMatrix m = new RecursiveLayoutRealMatrix(n, n); + RealMatrix sub = new RecursiveLayoutRealMatrix(n, 1).scalarAdd(1); + + m.setColumnMatrix(2, sub); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + if (j != 2) { + assertEquals(0.0, m.getEntry(i, j), 0.0); + } else { + assertEquals(1.0, m.getEntry(i, j), 0.0); + } + } + } + assertEquals(sub, m.getColumnMatrix(2)); + + } + + public void testGetRowVector() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealVector mRow0 = new RealVectorImpl(subRow0[0]); + RealVector mRow3 = new RealVectorImpl(subRow3[0]); + assertEquals(mRow0, m.getRowVector(0)); + assertEquals(mRow3, m.getRowVector(3)); + try { + m.getRowVector(-1); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.getRowVector(4); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + } + + public void testSetRowVector() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealVector mRow3 = new RealVectorImpl(subRow3[0]); + assertNotSame(mRow3, m.getRowMatrix(0)); + m.setRowVector(0, mRow3); + assertEquals(mRow3, m.getRowVector(0)); + try { + m.setRowVector(-1, mRow3); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.setRowVector(0, new RealVectorImpl(5)); + fail("Expecting InvalidMatrixException"); + } catch (InvalidMatrixException ex) { + // expected + } + } + + public void testGetSetRowVectorLarge() { + int n = 3 * 64; + RealMatrix m = new RecursiveLayoutRealMatrix(n, n); + RealVector sub = new RealVectorImpl(n, 1.0); + + m.setRowVector(2, sub); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + if (i != 2) { + assertEquals(0.0, m.getEntry(i, j), 0.0); + } else { + assertEquals(1.0, m.getEntry(i, j), 0.0); + } + } + } + assertEquals(sub, m.getRowVector(2)); + + } + + public void testGetColumnVector() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealVector mColumn1 = columnToVector(subColumn1); + RealVector mColumn3 = columnToVector(subColumn3); + assertEquals(mColumn1, m.getColumnVector(1)); + assertEquals(mColumn3, m.getColumnVector(3)); + try { + m.getColumnVector(-1); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.getColumnVector(4); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + } + + public void testSetColumnVector() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + RealVector mColumn3 = columnToVector(subColumn3); + assertNotSame(mColumn3, m.getColumnVector(1)); + m.setColumnVector(1, mColumn3); + assertEquals(mColumn3, m.getColumnVector(1)); + try { + m.setColumnVector(-1, mColumn3); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.setColumnVector(0, new RealVectorImpl(5)); + fail("Expecting InvalidMatrixException"); + } catch (InvalidMatrixException ex) { + // expected + } + } + + public void testGetSetColumnVectorLarge() { + int n = 3 * 64; + RealMatrix m = new RecursiveLayoutRealMatrix(n, n); + RealVector sub = new RealVectorImpl(n, 1.0); + + m.setColumnVector(2, sub); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + if (j != 2) { + assertEquals(0.0, m.getEntry(i, j), 0.0); + } else { + assertEquals(1.0, m.getEntry(i, j), 0.0); + } + } + } + assertEquals(sub, m.getColumnVector(2)); + + } + + private RealVector columnToVector(double[][] column) { + double[] data = new double[column.length]; + for (int i = 0; i < data.length; ++i) { + data[i] = column[i][0]; + } + return new RealVectorImpl(data, false); + } + + public void testGetRow() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + checkArrays(subRow0[0], m.getRow(0)); + checkArrays(subRow3[0], m.getRow(3)); + try { + m.getRow(-1); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.getRow(4); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + } + + public void testSetRow() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + assertTrue(subRow3[0][0] != m.getRow(0)[0]); + m.setRow(0, subRow3[0]); + checkArrays(subRow3[0], m.getRow(0)); + try { + m.setRow(-1, subRow3[0]); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.setRow(0, new double[5]); + fail("Expecting InvalidMatrixException"); + } catch (InvalidMatrixException ex) { + // expected + } + } + + public void testGetSetRowLarge() { + int n = 3 * 64; + RealMatrix m = new RecursiveLayoutRealMatrix(n, n); + double[] sub = new double[n]; + Arrays.fill(sub, 1.0); + + m.setRow(2, sub); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + if (i != 2) { + assertEquals(0.0, m.getEntry(i, j), 0.0); + } else { + assertEquals(1.0, m.getEntry(i, j), 0.0); + } + } + } + checkArrays(sub, m.getRow(2)); + + } + + public void testGetColumn() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + double[] mColumn1 = columnToArray(subColumn1); + double[] mColumn3 = columnToArray(subColumn3); + checkArrays(mColumn1, m.getColumn(1)); + checkArrays(mColumn3, m.getColumn(3)); + try { + m.getColumn(-1); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.getColumn(4); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + } + + public void testSetColumn() { + RealMatrix m = new RecursiveLayoutRealMatrix(subTestData); + double[] mColumn3 = columnToArray(subColumn3); + assertTrue(mColumn3[0] != m.getColumn(1)[0]); + m.setColumn(1, mColumn3); + checkArrays(mColumn3, m.getColumn(1)); + try { + m.setColumn(-1, mColumn3); + fail("Expecting MatrixIndexException"); + } catch (MatrixIndexException ex) { + // expected + } + try { + m.setColumn(0, new double[5]); + fail("Expecting InvalidMatrixException"); + } catch (InvalidMatrixException ex) { + // expected + } + } + + public void testGetSetColumnLarge() { + int n = 3 * 64; + RealMatrix m = new RecursiveLayoutRealMatrix(n, n); + double[] sub = new double[n]; + Arrays.fill(sub, 1.0); + + m.setColumn(2, sub); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + if (j != 2) { + assertEquals(0.0, m.getEntry(i, j), 0.0); + } else { + assertEquals(1.0, m.getEntry(i, j), 0.0); + } + } + } + checkArrays(sub, m.getColumn(2)); + + } + + private double[] columnToArray(double[][] column) { + double[] data = new double[column.length]; + for (int i = 0; i < data.length; ++i) { + data[i] = column[i][0]; + } + return data; + } + + private void checkArrays(double[] expected, double[] actual) { + assertEquals(expected.length, actual.length); + for (int i = 0; i < expected.length; ++i) { + assertEquals(expected[i], actual[i], 1.0e-9 * Math.abs(expected[i])); + } + } + + public void testEqualsAndHashCode() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + RecursiveLayoutRealMatrix m1 = (RecursiveLayoutRealMatrix) m.copy(); + RecursiveLayoutRealMatrix mt = (RecursiveLayoutRealMatrix) m.transpose(); + assertTrue(m.hashCode() != mt.hashCode()); + assertEquals(m.hashCode(), m1.hashCode()); + assertEquals(m, m); + assertEquals(m, m1); + assertFalse(m.equals(null)); + assertFalse(m.equals(mt)); + assertFalse(m.equals(new RecursiveLayoutRealMatrix(bigSingular))); + } + + public void testToString() { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + assertEquals("RecursiveLayoutRealMatrix{{1.0,2.0,3.0},{2.0,5.0,3.0},{1.0,0.0,8.0}}", + m.toString()); + } + + public void testSetSubMatrix() throws Exception { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(testData); + m.setSubMatrix(detData2,1,1); + RealMatrix expected = new RecursiveLayoutRealMatrix + (new double[][] {{1.0,2.0,3.0},{2.0,1.0,3.0},{1.0,2.0,4.0}}); + assertEquals(expected, m); + + m.setSubMatrix(detData2,0,0); + expected = new RecursiveLayoutRealMatrix + (new double[][] {{1.0,3.0,3.0},{2.0,4.0,3.0},{1.0,2.0,4.0}}); + assertEquals(expected, m); + + m.setSubMatrix(testDataPlus2,0,0); + expected = new RecursiveLayoutRealMatrix + (new double[][] {{3.0,4.0,5.0},{4.0,7.0,5.0},{3.0,2.0,10.0}}); + assertEquals(expected, m); + + // javadoc example + RecursiveLayoutRealMatrix matrix = new RecursiveLayoutRealMatrix + (new double[][] {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 0, 1 , 2}}); + matrix.setSubMatrix(new double[][] {{3, 4}, {5, 6}}, 1, 1); + expected = new RecursiveLayoutRealMatrix + (new double[][] {{1, 2, 3, 4}, {5, 3, 4, 8}, {9, 5 ,6, 2}}); + assertEquals(expected, matrix); + + // dimension overflow + try { + m.setSubMatrix(testData,1,1); + fail("expecting MatrixIndexException"); + } catch (MatrixIndexException e) { + // expected + } + // dimension underflow + try { + m.setSubMatrix(testData,-1,1); + fail("expecting MatrixIndexException"); + } catch (MatrixIndexException e) { + // expected + } + try { + m.setSubMatrix(testData,1,-1); + fail("expecting MatrixIndexException"); + } catch (MatrixIndexException e) { + // expected + } + + // null + try { + m.setSubMatrix(null,1,1); + fail("expecting NullPointerException"); + } catch (NullPointerException e) { + // expected + } + + // ragged + try { + m.setSubMatrix(new double[][] {{1}, {2, 3}}, 0, 0); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException e) { + // expected + } + + // empty + try { + m.setSubMatrix(new double[][] {{}}, 0, 0); + fail("expecting IllegalArgumentException"); + } catch (IllegalArgumentException e) { + // expected + } + + } + + public void testWalk() { + int rows = 150; + int columns = 75; + + RealMatrix m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInRowOrder(new SetVisitor()); + GetVisitor getVisitor = new GetVisitor(); + m.walkInOptimizedOrder(getVisitor); + assertEquals(rows * columns, getVisitor.getCount()); + + m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInRowOrder(new SetVisitor(), 1, rows - 2, 1, columns - 2); + getVisitor = new GetVisitor(); + m.walkInOptimizedOrder(getVisitor, 1, rows - 2, 1, columns - 2); + assertEquals((rows - 2) * (columns - 2), getVisitor.getCount()); + for (int i = 0; i < rows; ++i) { + assertEquals(0.0, m.getEntry(i, 0), 0); + assertEquals(0.0, m.getEntry(i, columns - 1), 0); + } + for (int j = 0; j < columns; ++j) { + assertEquals(0.0, m.getEntry(0, j), 0); + assertEquals(0.0, m.getEntry(rows - 1, j), 0); + } + + m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInColumnOrder(new SetVisitor()); + getVisitor = new GetVisitor(); + m.walkInOptimizedOrder(getVisitor); + assertEquals(rows * columns, getVisitor.getCount()); + + m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInColumnOrder(new SetVisitor(), 1, rows - 2, 1, columns - 2); + getVisitor = new GetVisitor(); + m.walkInOptimizedOrder(getVisitor, 1, rows - 2, 1, columns - 2); + assertEquals((rows - 2) * (columns - 2), getVisitor.getCount()); + for (int i = 0; i < rows; ++i) { + assertEquals(0.0, m.getEntry(i, 0), 0); + assertEquals(0.0, m.getEntry(i, columns - 1), 0); + } + for (int j = 0; j < columns; ++j) { + assertEquals(0.0, m.getEntry(0, j), 0); + assertEquals(0.0, m.getEntry(rows - 1, j), 0); + } + + m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInOptimizedOrder(new SetVisitor()); + getVisitor = new GetVisitor(); + m.walkInRowOrder(getVisitor); + assertEquals(rows * columns, getVisitor.getCount()); + + m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInOptimizedOrder(new SetVisitor(), 1, rows - 2, 1, columns - 2); + getVisitor = new GetVisitor(); + m.walkInRowOrder(getVisitor, 1, rows - 2, 1, columns - 2); + assertEquals((rows - 2) * (columns - 2), getVisitor.getCount()); + for (int i = 0; i < rows; ++i) { + assertEquals(0.0, m.getEntry(i, 0), 0); + assertEquals(0.0, m.getEntry(i, columns - 1), 0); + } + for (int j = 0; j < columns; ++j) { + assertEquals(0.0, m.getEntry(0, j), 0); + assertEquals(0.0, m.getEntry(rows - 1, j), 0); + } + + m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInOptimizedOrder(new SetVisitor()); + getVisitor = new GetVisitor(); + m.walkInColumnOrder(getVisitor); + assertEquals(rows * columns, getVisitor.getCount()); + + m = new RecursiveLayoutRealMatrix(rows, columns); + m.walkInOptimizedOrder(new SetVisitor(), 1, rows - 2, 1, columns - 2); + getVisitor = new GetVisitor(); + m.walkInColumnOrder(getVisitor, 1, rows - 2, 1, columns - 2); + assertEquals((rows - 2) * (columns - 2), getVisitor.getCount()); + for (int i = 0; i < rows; ++i) { + assertEquals(0.0, m.getEntry(i, 0), 0); + assertEquals(0.0, m.getEntry(i, columns - 1), 0); + } + for (int j = 0; j < columns; ++j) { + assertEquals(0.0, m.getEntry(0, j), 0); + assertEquals(0.0, m.getEntry(rows - 1, j), 0); + } + + } + + private static class SetVisitor extends DefaultRealMatrixChangingVisitor { + private static final long serialVersionUID = 1773444180892369386L; + public double visit(int i, int j, double value) { + return i + j / 1024.0; + } + } + + private static class GetVisitor extends DefaultRealMatrixPreservingVisitor { + private static final long serialVersionUID = -7745543227178932689L; + private int count = 0; + public void visit(int i, int j, double value) { + ++count; + assertEquals(i + j / 1024.0, value, 0.0); + } + public int getCount() { + return count; + } + }; + + //--------------- -----------------Protected methods + + /** verifies that two matrices are close (1-norm) */ + protected void assertClose(RealMatrix m, RealMatrix n, double tolerance) { + assertTrue(m.subtract(n).getNorm() < tolerance); + } + + /** verifies that two vectors are close (sup norm) */ + protected void assertClose(double[] m, double[] n, double tolerance) { + if (m.length != n.length) { + fail("vectors not same length"); + } + for (int i = 0; i < m.length; i++) { + assertEquals(m[i], n[i], tolerance); + } + } + + private RecursiveLayoutRealMatrix createRandomMatrix(Random r, int rows, int columns) { + RecursiveLayoutRealMatrix m = new RecursiveLayoutRealMatrix(rows, columns); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < columns; ++j) { + m.setEntry(i, j, 200 * r.nextDouble() - 100); + } + } + return m; + } + +} +