added new cache-friendly specializations of get/set/operate/premultiply methods

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@729174 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Luc Maisonobe 2008-12-23 23:30:33 +00:00
parent ef8ee92f48
commit 0e1654b627
2 changed files with 624 additions and 34 deletions

View File

@ -44,7 +44,7 @@ public class DenseRealMatrix extends AbstractRealMatrix implements Serializable
private static final long serialVersionUID = 4991895511313664478L; private static final long serialVersionUID = 4991895511313664478L;
/** Block size. */ /** Block size. */
private static final int BLOCK_SIZE = 52; public static final int BLOCK_SIZE = 52;
/** Blocks of matrix entries. */ /** Blocks of matrix entries. */
private final double blocks[][]; private final double blocks[][];
@ -646,6 +646,408 @@ public class DenseRealMatrix extends AbstractRealMatrix implements Serializable
} }
} }
/** {@inheritDoc} */
public void setSubMatrix(final double[][] subMatrix, final int row, final int column)
throws MatrixIndexException {
// safety checks
final int refLength = subMatrix[0].length;
if (refLength < 1) {
throw MathRuntimeException.createIllegalArgumentException("matrix must have at least one column",
null);
}
final int endRow = row + subMatrix.length - 1;
final int endColumn = column + refLength - 1;
checkSubMatrixIndex(row, endRow, column, endColumn);
for (final double[] subRow : subMatrix) {
if (subRow.length != refLength) {
throw MathRuntimeException.createIllegalArgumentException("some rows have length {0} while others have length {1}",
new Object[] {
refLength, subRow.length
});
}
}
// compute blocks bounds
final int blockStartRow = row / BLOCK_SIZE;
final int blockEndRow = (endRow + BLOCK_SIZE) / BLOCK_SIZE;
final int blockStartColumn = column / BLOCK_SIZE;
final int blockEndColumn = (endColumn + BLOCK_SIZE) / BLOCK_SIZE;
// perform copy block-wise, to ensure good cache behavior
for (int iBlock = blockStartRow; iBlock < blockEndRow; ++iBlock) {
final int iHeight = blockHeight(iBlock);
final int firstRow = iBlock * BLOCK_SIZE;
final int iStart = Math.max(row, firstRow);
final int iEnd = Math.min(endRow + 1, firstRow + iHeight);
for (int jBlock = blockStartColumn; jBlock < blockEndColumn; ++jBlock) {
final int jWidth = blockWidth(jBlock);
final int firstColumn = jBlock * BLOCK_SIZE;
final int jStart = Math.max(column, firstColumn);
final int jEnd = Math.min(endColumn + 1, firstColumn + jWidth);
final int jLength = jEnd - jStart;
// handle one block, row by row
final double[] block = blocks[iBlock * blockColumns + jBlock];
for (int i = iStart; i < iEnd; ++i) {
System.arraycopy(subMatrix[i - row], jStart - column,
block, (i - firstRow) * jWidth + (jStart - firstColumn),
jLength);
}
}
}
}
/** {@inheritDoc} */
public RealMatrix getRowMatrix(final int row)
throws MatrixIndexException {
checkRowIndex(row);
final DenseRealMatrix out = new DenseRealMatrix(1, columns);
// perform copy block-wise, to ensure good cache behavior
final int iBlock = row / BLOCK_SIZE;
final int iRow = row - iBlock * BLOCK_SIZE;
int outBlockIndex = 0;
int outIndex = 0;
double[] outBlock = out.blocks[outBlockIndex];
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final int jWidth = blockWidth(jBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
final int available = outBlock.length - outIndex;
if (jWidth > available) {
System.arraycopy(block, iRow * jWidth, outBlock, outIndex, available);
outBlock = out.blocks[++outBlockIndex];
System.arraycopy(block, iRow * jWidth, outBlock, 0, jWidth - available);
outIndex = jWidth - available;
} else {
System.arraycopy(block, iRow * jWidth, outBlock, outIndex, jWidth);
outIndex += jWidth;
}
}
return out;
}
/** {@inheritDoc} */
public void setRowMatrix(final int row, final RealMatrix matrix)
throws MatrixIndexException, InvalidMatrixException {
try {
setRowMatrix(row, (DenseRealMatrix) matrix);
} catch (ClassCastException cce) {
super.setRowMatrix(row, matrix);
}
}
/**
* Sets the entries in row number <code>row</code>
* as a row matrix. Row indices start at 0.
*
* @param row the row to be set
* @param matrix row matrix (must have one row and the same number of columns
* as the instance)
* @throws MatrixIndexException if the specified row index is invalid
* @throws InvalidMatrixException if the matrix dimensions do not match one
* instance row
*/
public void setRowMatrix(final int row, final DenseRealMatrix matrix)
throws MatrixIndexException, InvalidMatrixException {
checkRowIndex(row);
final int nCols = getColumnDimension();
if ((matrix.getRowDimension() != 1) ||
(matrix.getColumnDimension() != nCols)) {
throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}",
new Object[] {
matrix.getRowDimension(),
matrix.getColumnDimension(),
1, nCols
});
}
// perform copy block-wise, to ensure good cache behavior
final int iBlock = row / BLOCK_SIZE;
final int iRow = row - iBlock * BLOCK_SIZE;
int mBlockIndex = 0;
int mIndex = 0;
double[] mBlock = matrix.blocks[mBlockIndex];
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final int jWidth = blockWidth(jBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
final int available = mBlock.length - mIndex;
if (jWidth > available) {
System.arraycopy(mBlock, mIndex, block, iRow * jWidth, available);
mBlock = matrix.blocks[++mBlockIndex];
System.arraycopy(mBlock, 0, block, iRow * jWidth, jWidth - available);
mIndex = jWidth - available;
} else {
System.arraycopy(mBlock, mIndex, block, iRow * jWidth, jWidth);
mIndex += jWidth;
}
}
}
/** {@inheritDoc} */
public RealMatrix getColumnMatrix(final int column)
throws MatrixIndexException {
checkColumnIndex(column);
final DenseRealMatrix out = new DenseRealMatrix(rows, 1);
// perform copy block-wise, to ensure good cache behavior
final int jBlock = column / BLOCK_SIZE;
final int jColumn = column - jBlock * BLOCK_SIZE;
final int jWidth = blockWidth(jBlock);
int outBlockIndex = 0;
int outIndex = 0;
double[] outBlock = out.blocks[outBlockIndex];
for (int iBlock = 0; iBlock < blockRows; ++iBlock) {
final int iHeight = blockHeight(iBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
for (int i = 0; i < iHeight; ++i) {
if (outIndex >= outBlock.length) {
outBlock = out.blocks[++outBlockIndex];
outIndex = 0;
}
outBlock[outIndex++] = block[i * jWidth + jColumn];
}
}
return out;
}
/** {@inheritDoc} */
public void setColumnMatrix(final int column, final RealMatrix matrix)
throws MatrixIndexException, InvalidMatrixException {
try {
setColumnMatrix(column, (DenseRealMatrix) matrix);
} catch (ClassCastException cce) {
super.setColumnMatrix(column, matrix);
}
}
/**
* Sets the entries in column number <code>column</code>
* as a column matrix. Column indices start at 0.
*
* @param column the column to be set
* @param matrix column matrix (must have one column and the same number of rows
* as the instance)
* @throws MatrixIndexException if the specified column index is invalid
* @throws InvalidMatrixException if the matrix dimensions do not match one
* instance column
*/
void setColumnMatrix(final int column, final DenseRealMatrix matrix)
throws MatrixIndexException, InvalidMatrixException {
checkColumnIndex(column);
final int nRows = getRowDimension();
if ((matrix.getRowDimension() != nRows) ||
(matrix.getColumnDimension() != 1)) {
throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}",
new Object[] {
matrix.getRowDimension(),
matrix.getColumnDimension(),
nRows, 1
});
}
// perform copy block-wise, to ensure good cache behavior
final int jBlock = column / BLOCK_SIZE;
final int jColumn = column - jBlock * BLOCK_SIZE;
final int jWidth = blockWidth(jBlock);
int mBlockIndex = 0;
int mIndex = 0;
double[] mBlock = matrix.blocks[mBlockIndex];
for (int iBlock = 0; iBlock < blockRows; ++iBlock) {
final int iHeight = blockHeight(iBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
for (int i = 0; i < iHeight; ++i) {
if (mIndex >= mBlock.length) {
mBlock = matrix.blocks[++mBlockIndex];
mIndex = 0;
}
block[i * jWidth + jColumn] = mBlock[mIndex++];
}
}
}
/** {@inheritDoc} */
public RealVector getRowVector(final int row)
throws MatrixIndexException {
checkRowIndex(row);
final RealVectorImpl out = new RealVectorImpl(columns);
// perform copy block-wise, to ensure good cache behavior
final int iBlock = row / BLOCK_SIZE;
final int iRow = row - iBlock * BLOCK_SIZE;
int outIndex = 0;
double[] outData = out.getDataRef();
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final int jWidth = blockWidth(jBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
System.arraycopy(block, iRow * jWidth, outData, outIndex, jWidth);
outIndex += jWidth;
}
return out;
}
/** {@inheritDoc} */
public void setRowVector(final int row, final RealVector vector)
throws MatrixIndexException, InvalidMatrixException {
try {
setRow(row, ((RealVectorImpl) vector).getDataRef());
} catch (ClassCastException cce) {
super.setRowVector(row, vector);
}
}
/** {@inheritDoc} */
public RealVector getColumnVector(final int column)
throws MatrixIndexException {
checkColumnIndex(column);
final RealVectorImpl out = new RealVectorImpl(rows);
// perform copy block-wise, to ensure good cache behavior
final int jBlock = column / BLOCK_SIZE;
final int jColumn = column - jBlock * BLOCK_SIZE;
final int jWidth = blockWidth(jBlock);
int outIndex = 0;
double[] outData = out.getDataRef();
for (int iBlock = 0; iBlock < blockRows; ++iBlock) {
final int iHeight = blockHeight(iBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
for (int i = 0; i < iHeight; ++i) {
outData[outIndex++] = block[i * jWidth + jColumn];
}
}
return out;
}
/** {@inheritDoc} */
public void setColumnVector(final int column, final RealVector vector)
throws MatrixIndexException, InvalidMatrixException {
try {
setColumn(column, ((RealVectorImpl) vector).getDataRef());
} catch (ClassCastException cce) {
super.setColumnVector(column, vector);
}
}
/** {@inheritDoc} */
public double[] getRow(final int row)
throws MatrixIndexException {
checkRowIndex(row);
final double[] out = new double[columns];
// perform copy block-wise, to ensure good cache behavior
final int iBlock = row / BLOCK_SIZE;
final int iRow = row - iBlock * BLOCK_SIZE;
int outIndex = 0;
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final int jWidth = blockWidth(jBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
System.arraycopy(block, iRow * jWidth, out, outIndex, jWidth);
outIndex += jWidth;
}
return out;
}
/** {@inheritDoc} */
public void setRow(final int row, final double[] array)
throws MatrixIndexException, InvalidMatrixException {
checkRowIndex(row);
final int nCols = getColumnDimension();
if (array.length != nCols) {
throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}",
new Object[] {
1, array.length,
1, nCols
});
}
// perform copy block-wise, to ensure good cache behavior
final int iBlock = row / BLOCK_SIZE;
final int iRow = row - iBlock * BLOCK_SIZE;
int outIndex = 0;
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final int jWidth = blockWidth(jBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
System.arraycopy(array, outIndex, block, iRow * jWidth, jWidth);
outIndex += jWidth;
}
}
/** {@inheritDoc} */
public double[] getColumn(final int column)
throws MatrixIndexException {
checkColumnIndex(column);
final double[] out = new double[rows];
// perform copy block-wise, to ensure good cache behavior
final int jBlock = column / BLOCK_SIZE;
final int jColumn = column - jBlock * BLOCK_SIZE;
final int jWidth = blockWidth(jBlock);
int outIndex = 0;
for (int iBlock = 0; iBlock < blockRows; ++iBlock) {
final int iHeight = blockHeight(iBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
for (int i = 0; i < iHeight; ++i) {
out[outIndex++] = block[i * jWidth + jColumn];
}
}
return out;
}
/** {@inheritDoc} */
public void setColumn(final int column, final double[] array)
throws MatrixIndexException, InvalidMatrixException {
checkColumnIndex(column);
final int nRows = getRowDimension();
if (array.length != nRows) {
throw new InvalidMatrixException("dimensions mismatch: got {0}x{1} but expected {2}x{3}",
new Object[] {
array.length, 1,
nRows, 1
});
}
// perform copy block-wise, to ensure good cache behavior
final int jBlock = column / BLOCK_SIZE;
final int jColumn = column - jBlock * BLOCK_SIZE;
final int jWidth = blockWidth(jBlock);
int outIndex = 0;
for (int iBlock = 0; iBlock < blockRows; ++iBlock) {
final int iHeight = blockHeight(iBlock);
final double[] block = blocks[iBlock * blockColumns + jBlock];
for (int i = 0; i < iHeight; ++i) {
block[i * jWidth + jColumn] = array[outIndex++];
}
}
}
/** {@inheritDoc} */ /** {@inheritDoc} */
public double getEntry(final int row, final int column) public double getEntry(final int row, final int column)
throws MatrixIndexException { throws MatrixIndexException {
@ -768,20 +1170,21 @@ public class DenseRealMatrix extends AbstractRealMatrix implements Serializable
public double[] operate(final double[] v) public double[] operate(final double[] v)
throws IllegalArgumentException { throws IllegalArgumentException {
final int nRows = this.getRowDimension(); if (v.length != columns) {
final int nCols = this.getColumnDimension(); throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" +
if (v.length != nCols) { " got {0} but expected {1}",
throw new IllegalArgumentException("vector has wrong length"); new Object[] {
v.length, columns
});
} }
final double[] out = new double[nRows]; final double[] out = new double[rows];
// perform multiplication block-wise, to ensure good cache behavior // perform multiplication block-wise, to ensure good cache behavior
int blockIndex = 0;
for (int iBlock = 0; iBlock < blockRows; ++iBlock) { for (int iBlock = 0; iBlock < blockRows; ++iBlock) {
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final double[] block = blocks[blockIndex];
final int pStart = iBlock * BLOCK_SIZE; final int pStart = iBlock * BLOCK_SIZE;
final int pEnd = Math.min(pStart + BLOCK_SIZE, rows); final int pEnd = Math.min(pStart + BLOCK_SIZE, rows);
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final double[] block = blocks[iBlock * blockColumns + jBlock];
final int qStart = jBlock * BLOCK_SIZE; final int qStart = jBlock * BLOCK_SIZE;
final int qEnd = Math.min(qStart + BLOCK_SIZE, columns); final int qEnd = Math.min(qStart + BLOCK_SIZE, columns);
for (int p = pStart, k = 0; p < pEnd; ++p) { for (int p = pStart, k = 0; p < pEnd; ++p) {
@ -800,7 +1203,6 @@ public class DenseRealMatrix extends AbstractRealMatrix implements Serializable
} }
out[p] += sum; out[p] += sum;
} }
++blockIndex;
} }
} }
@ -809,25 +1211,53 @@ public class DenseRealMatrix extends AbstractRealMatrix implements Serializable
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
public RealVector operate(final RealVector v) public double[] preMultiply(final double[] v)
throws IllegalArgumentException { throws IllegalArgumentException {
try {
return operate((RealVectorImpl) v); if (v.length != rows) {
} catch (ClassCastException cce) { throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" +
return super.operate(v); " got {0} but expected {1}",
new Object[] {
v.length, rows
});
}
final double[] out = new double[columns];
// perform multiplication block-wise, to ensure good cache behavior
for (int jBlock = 0; jBlock < blockColumns; ++jBlock) {
final int jWidth = blockWidth(jBlock);
final int jWidth2 = jWidth + jWidth;
final int jWidth3 = jWidth2 + jWidth;
final int jWidth4 = jWidth3 + jWidth;
final int qStart = jBlock * BLOCK_SIZE;
final int qEnd = Math.min(qStart + BLOCK_SIZE, columns);
for (int iBlock = 0; iBlock < blockRows; ++iBlock) {
final double[] block = blocks[iBlock * blockColumns + jBlock];
final int pStart = iBlock * BLOCK_SIZE;
final int pEnd = Math.min(pStart + BLOCK_SIZE, rows);
for (int q = qStart; q < qEnd; ++q) {
int k = q - qStart;
double sum = 0;
int p = pStart;
while (p < pEnd - 3) {
sum += block[k] * v[p] +
block[k + jWidth] * v[p + 1] +
block[k + jWidth2] * v[p + 2] +
block[k + jWidth3] * v[p + 3];
k += jWidth4;
p += 4;
}
while (p < pEnd) {
sum += block[k] * v[p++];
k += jWidth;
}
out[q] += sum;
}
} }
} }
/** return out;
* Returns the result of multiplying this by the vector <code>v</code>.
*
* @param v the vector to operate on
* @return this*v
* @throws IllegalArgumentException if columnDimension != v.size()
*/
public RealVectorImpl operate(final RealVectorImpl v)
throws IllegalArgumentException {
return new RealVectorImpl(operate(v.getDataRef()), false);
} }
/** /**

View File

@ -16,6 +16,7 @@
*/ */
package org.apache.commons.math.linear; package org.apache.commons.math.linear;
import java.util.Arrays;
import java.util.Random; import java.util.Random;
import junit.framework.Test; import junit.framework.Test;
@ -332,6 +333,32 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testOperateLarge() {
int p = (7 * DenseRealMatrix.BLOCK_SIZE) / 2;
int q = (5 * DenseRealMatrix.BLOCK_SIZE) / 2;
int r = 3 * DenseRealMatrix.BLOCK_SIZE;
Random random = new Random(111007463902334l);
RealMatrix m1 = createRandomMatrix(random, p, q);
RealMatrix m2 = createRandomMatrix(random, q, r);
RealMatrix m1m2 = m1.multiply(m2);
for (int i = 0; i < r; ++i) {
checkArrays(m1m2.getColumn(i), m1.operate(m2.getColumn(i)));
}
}
public void testOperatePremultiplyLarge() {
int p = (7 * DenseRealMatrix.BLOCK_SIZE) / 2;
int q = (5 * DenseRealMatrix.BLOCK_SIZE) / 2;
int r = 3 * DenseRealMatrix.BLOCK_SIZE;
Random random = new Random(111007463902334l);
RealMatrix m1 = createRandomMatrix(random, p, q);
RealMatrix m2 = createRandomMatrix(random, q, r);
RealMatrix m1m2 = m1.multiply(m2);
for (int i = 0; i < p; ++i) {
checkArrays(m1m2.getRow(i), m2.preMultiply(m1.getRow(i)));
}
}
/** test issue MATH-209 */ /** test issue MATH-209 */
public void testMath209() { public void testMath209() {
RealMatrix a = new DenseRealMatrix(new double[][] { RealMatrix a = new DenseRealMatrix(new double[][] {
@ -507,14 +534,31 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testGetSetMatrixLarge() {
int n = 3 * DenseRealMatrix.BLOCK_SIZE;
RealMatrix m = new DenseRealMatrix(n, n);
RealMatrix sub = new DenseRealMatrix(n - 4, n - 4).scalarAdd(1);
m.setSubMatrix(sub.getData(), 2, 2);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if ((i < 2) || (i > n - 3) || (j < 2) || (j > n - 3)) {
assertEquals(0.0, m.getEntry(i, j), 0.0);
} else {
assertEquals(1.0, m.getEntry(i, j), 0.0);
}
}
}
assertEquals(sub, m.getSubMatrix(2, n - 3, 2, n - 3));
}
public void testGetRowMatrix() { public void testGetRowMatrix() {
RealMatrix m = new DenseRealMatrix(subTestData); RealMatrix m = new DenseRealMatrix(subTestData);
RealMatrix mRow0 = new DenseRealMatrix(subRow0); RealMatrix mRow0 = new DenseRealMatrix(subRow0);
RealMatrix mRow3 = new DenseRealMatrix(subRow3); RealMatrix mRow3 = new DenseRealMatrix(subRow3);
assertEquals("Row0", mRow0, assertEquals("Row0", mRow0, m.getRowMatrix(0));
m.getRowMatrix(0)); assertEquals("Row3", mRow3, m.getRowMatrix(3));
assertEquals("Row3", mRow3,
m.getRowMatrix(3));
try { try {
m.getRowMatrix(-1); m.getRowMatrix(-1);
fail("Expecting MatrixIndexException"); fail("Expecting MatrixIndexException");
@ -549,6 +593,25 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testGetSetRowMatrixLarge() {
int n = 3 * DenseRealMatrix.BLOCK_SIZE;
RealMatrix m = new DenseRealMatrix(n, n);
RealMatrix sub = new DenseRealMatrix(1, n).scalarAdd(1);
m.setRowMatrix(2, sub);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (i != 2) {
assertEquals(0.0, m.getEntry(i, j), 0.0);
} else {
assertEquals(1.0, m.getEntry(i, j), 0.0);
}
}
}
assertEquals(sub, m.getRowMatrix(2));
}
public void testGetColumnMatrix() { public void testGetColumnMatrix() {
RealMatrix m = new DenseRealMatrix(subTestData); RealMatrix m = new DenseRealMatrix(subTestData);
RealMatrix mColumn1 = new DenseRealMatrix(subColumn1); RealMatrix mColumn1 = new DenseRealMatrix(subColumn1);
@ -589,6 +652,25 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testGetSetColumnMatrixLarge() {
int n = 3 * DenseRealMatrix.BLOCK_SIZE;
RealMatrix m = new DenseRealMatrix(n, n);
RealMatrix sub = new DenseRealMatrix(n, 1).scalarAdd(1);
m.setColumnMatrix(2, sub);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (j != 2) {
assertEquals(0.0, m.getEntry(i, j), 0.0);
} else {
assertEquals(1.0, m.getEntry(i, j), 0.0);
}
}
}
assertEquals(sub, m.getColumnMatrix(2));
}
public void testGetRowVector() { public void testGetRowVector() {
RealMatrix m = new DenseRealMatrix(subTestData); RealMatrix m = new DenseRealMatrix(subTestData);
RealVector mRow0 = new RealVectorImpl(subRow0[0]); RealVector mRow0 = new RealVectorImpl(subRow0[0]);
@ -629,6 +711,25 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testGetSetRowVectorLarge() {
int n = 3 * DenseRealMatrix.BLOCK_SIZE;
RealMatrix m = new DenseRealMatrix(n, n);
RealVector sub = new RealVectorImpl(n, 1.0);
m.setRowVector(2, sub);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (i != 2) {
assertEquals(0.0, m.getEntry(i, j), 0.0);
} else {
assertEquals(1.0, m.getEntry(i, j), 0.0);
}
}
}
assertEquals(sub, m.getRowVector(2));
}
public void testGetColumnVector() { public void testGetColumnVector() {
RealMatrix m = new DenseRealMatrix(subTestData); RealMatrix m = new DenseRealMatrix(subTestData);
RealVector mColumn1 = columnToVector(subColumn1); RealVector mColumn1 = columnToVector(subColumn1);
@ -669,6 +770,25 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testGetSetColumnVectorLarge() {
int n = 3 * DenseRealMatrix.BLOCK_SIZE;
RealMatrix m = new DenseRealMatrix(n, n);
RealVector sub = new RealVectorImpl(n, 1.0);
m.setColumnVector(2, sub);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (j != 2) {
assertEquals(0.0, m.getEntry(i, j), 0.0);
} else {
assertEquals(1.0, m.getEntry(i, j), 0.0);
}
}
}
assertEquals(sub, m.getColumnVector(2));
}
private RealVector columnToVector(double[][] column) { private RealVector columnToVector(double[][] column) {
double[] data = new double[column.length]; double[] data = new double[column.length];
for (int i = 0; i < data.length; ++i) { for (int i = 0; i < data.length; ++i) {
@ -714,6 +834,26 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testGetSetRowLarge() {
int n = 3 * DenseRealMatrix.BLOCK_SIZE;
RealMatrix m = new DenseRealMatrix(n, n);
double[] sub = new double[n];
Arrays.fill(sub, 1.0);
m.setRow(2, sub);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (i != 2) {
assertEquals(0.0, m.getEntry(i, j), 0.0);
} else {
assertEquals(1.0, m.getEntry(i, j), 0.0);
}
}
}
checkArrays(sub, m.getRow(2));
}
public void testGetColumn() { public void testGetColumn() {
RealMatrix m = new DenseRealMatrix(subTestData); RealMatrix m = new DenseRealMatrix(subTestData);
double[] mColumn1 = columnToArray(subColumn1); double[] mColumn1 = columnToArray(subColumn1);
@ -754,6 +894,26 @@ public final class DenseRealMatrixTest extends TestCase {
} }
} }
public void testGetSetColumnLarge() {
int n = 3 * DenseRealMatrix.BLOCK_SIZE;
RealMatrix m = new DenseRealMatrix(n, n);
double[] sub = new double[n];
Arrays.fill(sub, 1.0);
m.setColumn(2, sub);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
if (j != 2) {
assertEquals(0.0, m.getEntry(i, j), 0.0);
} else {
assertEquals(1.0, m.getEntry(i, j), 0.0);
}
}
}
checkArrays(sub, m.getColumn(2));
}
private double[] columnToArray(double[][] column) { private double[] columnToArray(double[][] column) {
double[] data = new double[column.length]; double[] data = new double[column.length];
for (int i = 0; i < data.length; ++i) { for (int i = 0; i < data.length; ++i) {