diff --git a/src/java/org/apache/commons/math/linear/AbstractRealMatrix.java b/src/java/org/apache/commons/math/linear/AbstractRealMatrix.java index 980ae45e2..643cb02c2 100644 --- a/src/java/org/apache/commons/math/linear/AbstractRealMatrix.java +++ b/src/java/org/apache/commons/math/linear/AbstractRealMatrix.java @@ -606,28 +606,30 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable { /** {@inheritDoc} */ public RealVector operate(final RealVector v) throws IllegalArgumentException { - - final int nRows = getRowDimension(); - final int nCols = getColumnDimension(); - if (v.getDimension() != nCols) { - throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" + - " got {0} but expected {1}", - new Object[] { - v.getDimension(), nCols - }); - } - - final double[] out = new double[nRows]; - for (int row = 0; row < nRows; ++row) { - double sum = 0; - for (int i = 0; i < nCols; ++i) { - sum += getEntry(row, i) * v.getEntry(i); + try { + return new RealVectorImpl(operate(((RealVectorImpl) v).getDataRef()), false); + } catch (ClassCastException cce) { + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); + if (v.getDimension() != nCols) { + throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" + + " got {0} but expected {1}", + new Object[] { + v.getDimension(), nCols + }); } - out[row] = sum; + + final double[] out = new double[nRows]; + for (int row = 0; row < nRows; ++row) { + double sum = 0; + for (int i = 0; i < nCols; ++i) { + sum += getEntry(row, i) * v.getEntry(i); + } + out[row] = sum; + } + + return new RealVectorImpl(out, false); } - - return new RealVectorImpl(out, false); - } /** {@inheritDoc} */ @@ -660,28 +662,32 @@ public abstract class AbstractRealMatrix implements RealMatrix, Serializable { /** {@inheritDoc} */ public RealVector preMultiply(final RealVector v) throws IllegalArgumentException { + try { + return new RealVectorImpl(preMultiply(((RealVectorImpl) v).getDataRef()), false); + } catch (ClassCastException cce) { - final int nRows = getRowDimension(); - final int nCols = getColumnDimension(); - if (v.getDimension() != nRows) { - throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" + - " got {0} but expected {1}", - new Object[] { - v.getDimension(), nRows - }); - } - - final double[] out = new double[nCols]; - for (int col = 0; col < nCols; ++col) { - double sum = 0; - for (int i = 0; i < nRows; ++i) { - sum += getEntry(i, col) * v.getEntry(i); + final int nRows = getRowDimension(); + final int nCols = getColumnDimension(); + if (v.getDimension() != nRows) { + throw MathRuntimeException.createIllegalArgumentException("vector length mismatch:" + + " got {0} but expected {1}", + new Object[] { + v.getDimension(), nRows + }); } - out[col] = sum; + + final double[] out = new double[nCols]; + for (int col = 0; col < nCols; ++col) { + double sum = 0; + for (int i = 0; i < nRows; ++i) { + sum += getEntry(i, col) * v.getEntry(i); + } + out[col] = sum; + } + + return new RealVectorImpl(out); + } - - return new RealVectorImpl(out); - } /** {@inheritDoc} */