Some changes to the interface of iterative linear solvers. Replaced the confusing boolean param inPlace in the solve() methods by two sets of methods: solve() and solveInPlace().

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1178073 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2011-10-01 19:09:00 +00:00
parent bbe44629cc
commit 1824dd31e6
4 changed files with 200 additions and 101 deletions

View File

@ -61,8 +61,9 @@ import org.apache.commons.math.util.IterationManager;
* <dt><a id="BARR1994">Barret et al. (1994)</a></dt>
* <dd>R. Barrett, M. Berry, T. F. Chan, J. Demmel, J. M. Donato, J. Dongarra,
* V. Eijkhout, R. Pozo, C. Romine and H. Van der Vorst,
* <em>Templates for the Solution of Linear Systems: Building Blocks for
* Iterative Methods</em>, SIAM</dd>
* <a href="http://www.netlib.org/linalg/html_templates/Templates.html"><em>
* Templates for the Solution of Linear Systems: Building Blocks for Iterative
* Methods</em></a>, SIAM</dd>
* <dt><a id="STRA2002">Strakos and Tichy (2002)
* <dt>
* <dd>Z. Strakos and P. Tichy, <a
@ -82,7 +83,8 @@ public class ConjugateGradient
* The type of all events fired by this implementation of the Conjugate
* Gradient method.
*
* @version $Id$
* @version $Id: ConjugateGradient.java 1175404 2011-09-25 14:48:18Z
* celestin $
*/
public abstract static class ConjugateGradientEvent
extends IterativeLinearSolverEvent
@ -95,7 +97,7 @@ public class ConjugateGradient
* Creates a new instance of this class.
*
* @param source The iterative algorithm on which the event initially
* occurred.
* occurred.
*/
public ConjugateGradientEvent(final Object source) {
super(source);
@ -124,7 +126,7 @@ public class ConjugateGradient
* @param maxIterations Maximum number of iterations.
* @param delta &delta; parameter for the default stopping criterion.
* @param check {@code true} if positive definiteness of both matrix and
* preconditioner should be checked.
* preconditioner should be checked.
*/
public ConjugateGradient(final int maxIterations, final double delta,
final boolean check) {
@ -140,7 +142,7 @@ public class ConjugateGradient
* @param manager Custom iteration manager.
* @param delta &delta; parameter for the default stopping criterion.
* @param check {@code true} if positive definiteness of both matrix and
* preconditioner should be checked.
* preconditioner should be checked.
*/
public ConjugateGradient(final IterationManager manager,
final double delta, final boolean check) {
@ -161,13 +163,12 @@ public class ConjugateGradient
/** {@inheritDoc} */
@Override
public RealVector solve(final RealLinearOperator a,
final InvertibleRealLinearOperator m,
final RealVector b, final RealVector x0,
final boolean inPlace)
public RealVector solveInPlace(final RealLinearOperator a,
final InvertibleRealLinearOperator m,
final RealVector b, final RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
checkParameters(a, m, b, x0, inPlace);
checkParameters(a, m, b, x0);
final IterationManager manager = getIterationManager();
// Initialization of default stopping criterion
manager.resetIterationCount();
@ -176,16 +177,7 @@ public class ConjugateGradient
// p and x are constructed as copies of x0, since presumably, the type
// of x is optimized for the calculation of the matrix-vector product
// A.x.
final RealVector x;
if (inPlace) {
x = x0;
} else {
if (x0 != null) {
x = x0.copy();
} else {
x = new ArrayRealVector(a.getColumnDimension());
}
}
final RealVector x = x0;
final RealVector p = x.copy();
RealVector q = a.operate(p);
manager.incrementIterationCount();

View File

@ -49,35 +49,36 @@ public abstract class IterativeLinearSolver {
* Creates a new instance of this class, with custom iteration manager.
*
* @param manager Custom iteration manager.
* @throws NullArgumentException if {@code manager} is {@code null}.
*/
public IterativeLinearSolver(final IterationManager manager) {
public IterativeLinearSolver(final IterationManager manager)
throws NullArgumentException {
MathUtils.checkNotNull(manager);
this.manager = manager;
}
/**
* Performs all dimension checks on the parameters of
* {@link #solve(RealLinearOperator, RealVector, RealVector, boolean)}, and
* throws an exception if one of the checks fails.
* {@link #solve(RealLinearOperator, RealVector, RealVector) solve} and
* {@link #solveInPlace(RealLinearOperator, RealVector, RealVector) solveInPlace},
* and throws an exception if one of the checks fails.
*
* @param a Linear operator A of the system.
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @param x0 Initial guess of the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} is not square.
* @throws DimensionMismatchException if {@code b} or {@code x0} have
* dimensions inconsistent with {@code a}.
* dimensions inconsistent with {@code a}.
*/
protected static void checkParameters(final RealLinearOperator a,
final RealVector b,
final RealVector x0,
final boolean inPlace)
final RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException {
MathUtils.checkNotNull(a);
MathUtils.checkNotNull(b);
MathUtils.checkNotNull(x0);
if (a.getRowDimension() != a.getColumnDimension()) {
throw new NonSquareLinearOperatorException(a.getRowDimension(),
a.getColumnDimension());
@ -86,10 +87,7 @@ public abstract class IterativeLinearSolver {
throw new DimensionMismatchException(b.getDimension(),
a.getRowDimension());
}
if (inPlace) {
MathUtils.checkNotNull(x0);
}
if ((x0 != null) && (x0.getDimension() != a.getColumnDimension())) {
if (x0.getDimension() != a.getColumnDimension()) {
throw new DimensionMismatchException(x0.getDimension(),
a.getColumnDimension());
}
@ -106,28 +104,70 @@ public abstract class IterativeLinearSolver {
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b. If no initial estimate of the solution is provided, (0, &hellip;, 0)
* is assumed.
* b.
*
* @param a Linear operator A of the system.
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @return A reference to {@code x0} (shallow copy) if {@code inPlace} was
* set to {@code true}. Otherwise, a new vector containing the
* solution.
* @return A new vector containing the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} is not square.
* @throws DimensionMismatchException if {@code b} has dimensions
* inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has been set at
* construction.
*/
public RealVector solve(RealLinearOperator a, RealVector b)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
MathUtils.checkNotNull(a);
final RealVector x = new ArrayRealVector(a.getColumnDimension());
x.set(0.);
return solveInPlace(a, b, x);
}
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b.
*
* @param a Linear operator A of the system.
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution.
* @return A new vector containing the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} is not square.
* @throws DimensionMismatchException if {@code b} or {@code x0} have
* dimensions inconsistent with {@code a}.
* dimensions inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has
* been set at construction.
* unless a custom {@link MaxCountExceededCallback callback} has been set at
* construction.
*/
public abstract RealVector solve(RealLinearOperator a, RealVector b,
RealVector x0, boolean inPlace)
public RealVector solve(RealLinearOperator a, RealVector b, RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
MathUtils.checkNotNull(x0);
return solveInPlace(a, b, x0.copy());
}
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b. The solution is computed in-place (initial guess is modified).
*
* @param a Linear operator A of the system.
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution.
* @return A reference to {@code x0} (shallow copy) updated with the
* solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} is not square.
* @throws DimensionMismatchException if {@code b} or {@code x0} have
* dimensions inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has been set at
* construction.
*/
public abstract RealVector solveInPlace(RealLinearOperator a, RealVector b,
RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException;
}

View File

@ -20,6 +20,7 @@ import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.exception.MaxCountExceededException;
import org.apache.commons.math.exception.NullArgumentException;
import org.apache.commons.math.util.IterationManager;
import org.apache.commons.math.util.MathUtils;
/**
* This abstract class defines preconditioned iterative solvers. When A is
@ -48,41 +49,88 @@ public abstract class PreconditionedIterativeLinearSolver
* Creates a new instance of this class, with custom iteration manager.
*
* @param manager Custom iteration manager.
* @throws NullArgumentException if {@code manager} is {@code null}.
*/
public PreconditionedIterativeLinearSolver(final IterationManager manager) {
public PreconditionedIterativeLinearSolver(final IterationManager manager)
throws NullArgumentException {
super(manager);
}
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b.
*
* @param a Linear operator A of the system.
* @param m Preconditioner (can be {@code null}).
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution.
* @return A new vector containing the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} or {@code m} is not
* square.
* @throws DimensionMismatchException if {@code m}, {@code b} or {@code x0}
* have dimensions inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has been set at
* construction.
*/
public RealVector solve(final RealLinearOperator a,
final InvertibleRealLinearOperator m,
final RealVector b, final RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
MathUtils.checkNotNull(x0);
return solveInPlace(a, m, b, x0.copy());
}
/** {@inheritDoc} */
@Override
public RealVector solve(final RealLinearOperator a, final RealVector b)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
MathUtils.checkNotNull(a);
final RealVector x = new ArrayRealVector(a.getColumnDimension());
x.set(0.);
return solveInPlace(a, null, b, x);
}
/** {@inheritDoc} */
@Override
public RealVector solve(final RealLinearOperator a, final RealVector b,
final RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
MathUtils.checkNotNull(x0);
return solveInPlace(a, null, b, x0.copy());
}
/**
* Performs all dimension checks on the parameters of
* {@link #solve(RealLinearOperator, InvertibleRealLinearOperator, RealVector, RealVector, boolean)}
* {@link #solve(RealLinearOperator, InvertibleRealLinearOperator, RealVector, RealVector) solve}
* and
* {@link #solveInPlace(RealLinearOperator, InvertibleRealLinearOperator, RealVector, RealVector) solveInPlace}
* , and throws an exception if one of the checks fails.
*
* @param a Linear operator A of the system.
* @param m Preconditioner (can be {@code null}).
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @param x0 Initial guess of the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} or {@code m} is not
* square.
* @throws DimensionMismatchException if {@code m}, {@code b} or {@code x}
* have dimensions inconsistent with {@code a}.
* square.
* @throws DimensionMismatchException if {@code m}, {@code b} or {@code x0}
* have dimensions inconsistent with {@code a}.
*/
protected static void checkParameters(final RealLinearOperator a,
final InvertibleRealLinearOperator m,
final RealVector b,
final RealVector x0,
final boolean inPlace)
final RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException {
checkParameters(a, b, x0, inPlace);
checkParameters(a, b, x0);
if (m != null) {
if (m.getColumnDimension() != m.getRowDimension()) {
throw new NonSquareLinearOperatorException(
m.getColumnDimension(),
throw new NonSquareLinearOperatorException(m.getColumnDimension(),
m.getRowDimension());
}
if (m.getRowDimension() != a.getRowDimension()) {
@ -94,42 +142,61 @@ public abstract class PreconditionedIterativeLinearSolver
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b. If no initial estimate of the solution is provided, (0, &hellip;, 0)
* is assumed.
* b.
*
* @param a Linear operator A of the system.
* @param m Preconditioner (can be {@code null}).
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @return A reference to {@code x0} (shallow copy) if {@code update} was
* set to {@code true}. Otherwise, a new vector containing the
* solution.
* @return A new vector containing the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} or {@code m} is not
* square.
* @throws DimensionMismatchException if {@code m}, {@code b} or {@code x}
* have dimensions inconsistent with {@code a}.
* square.
* @throws DimensionMismatchException if {@code m} or {@code b} have
* dimensions inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has
* been set at construction.
* unless a custom {@link MaxCountExceededCallback callback} has been set at
* construction.
*/
public abstract RealVector solve(RealLinearOperator a,
InvertibleRealLinearOperator m,
RealVector b, RealVector x0,
final boolean inPlace)
public RealVector solve(RealLinearOperator a,
InvertibleRealLinearOperator m, RealVector b)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
MathUtils.checkNotNull(a);
final RealVector x = new ArrayRealVector(a.getColumnDimension());
return solveInPlace(a, m, b, x);
}
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b. The solution is computed in-place (initial guess is modified).
*
* @param a Linear operator A of the system.
* @param m Preconditioner (can be {@code null}).
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution.
* @return A reference to {@code x0} (shallow copy) updated with the
* solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} or {@code m} is not
* square.
* @throws DimensionMismatchException if {@code m}, {@code b} or {@code x0}
* have dimensions inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has been set at
* construction.
*/
public abstract RealVector solveInPlace(RealLinearOperator a,
InvertibleRealLinearOperator m,
RealVector b, RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException;
/** {@inheritDoc} */
@Override
public RealVector solve(final RealLinearOperator a, final RealVector b,
final RealVector x0, final boolean inPlace)
public RealVector solveInPlace(final RealLinearOperator a,
final RealVector b, final RealVector x0)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
checkParameters(a, b, x0, inPlace);
return solve(a, null, b, x0, inPlace);
return solveInPlace(a, null, b, x0);
}
}

View File

@ -32,7 +32,7 @@ public class ConjugateGradientTest {
solver = new ConjugateGradient(10, 0., false);
final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
solver.solve(a, b, x, false);
solver.solve(a, b, x);
}
@Test(expected = DimensionMismatchException.class)
@ -42,7 +42,7 @@ public class ConjugateGradientTest {
solver = new ConjugateGradient(10, 0., false);
final ArrayRealVector b = new ArrayRealVector(2);
final ArrayRealVector x = new ArrayRealVector(3);
solver.solve(a, b, x, false);
solver.solve(a, b, x);
}
@Test(expected = DimensionMismatchException.class)
@ -52,7 +52,7 @@ public class ConjugateGradientTest {
solver = new ConjugateGradient(10, 0., false);
final ArrayRealVector b = new ArrayRealVector(3);
final ArrayRealVector x = new ArrayRealVector(2);
solver.solve(a, b, x, false);
solver.solve(a, b, x);
}
@Test(expected = NonPositiveDefiniteLinearOperatorException.class)
@ -68,7 +68,7 @@ public class ConjugateGradientTest {
b.setEntry(0, -1.);
b.setEntry(1, -1.);
final ArrayRealVector x = new ArrayRealVector(2);
solver.solve(a, b, x, false);
solver.solve(a, b, x);
}
@Test
@ -83,7 +83,7 @@ public class ConjugateGradientTest {
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, b, null, false);
final RealVector x = solver.solve(a, b);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
final double expected = ainv.getEntry(i, j);
@ -108,7 +108,7 @@ public class ConjugateGradientTest {
b.setEntry(j, 1.);
final RealVector x0 = new ArrayRealVector(n);
x0.set(1.);
final RealVector x = solver.solve(a, b, x0, true);
final RealVector x = solver.solveInPlace(a, b, x0);
Assert.assertSame("x should be a reference to x0", x0, x);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
@ -134,7 +134,7 @@ public class ConjugateGradientTest {
b.setEntry(j, 1.);
final RealVector x0 = new ArrayRealVector(n);
x0.set(1.);
final RealVector x = solver.solve(a, b, x0, false);
final RealVector x = solver.solve(a, b, x0);
Assert.assertNotSame("x should not be a reference to x0", x0, x);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
@ -186,7 +186,7 @@ public class ConjugateGradientTest {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, b, null, false);
final RealVector x = solver.solve(a, b);
final RealVector y = a.operate(x);
for (int i = 0; i < n; i++) {
final double actual = b.getEntry(i) - y.getEntry(i);
@ -228,7 +228,7 @@ public class ConjugateGradientTest {
final PreconditionedIterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0d, false);
final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
solver.solve(a, m, b, null, false);
solver.solve(a, m, b);
}
@Test(expected = DimensionMismatchException.class)
@ -260,7 +260,7 @@ public class ConjugateGradientTest {
final PreconditionedIterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0d, false);
final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
solver.solve(a, m, b, null, false);
solver.solve(a, m, b);
}
@Test(expected = NonPositiveDefiniteLinearOperatorException.class)
@ -304,7 +304,7 @@ public class ConjugateGradientTest {
final ArrayRealVector b = new ArrayRealVector(2);
b.setEntry(0, -1d);
b.setEntry(1, -1d);
solver.solve(a, m, b, null, false);
solver.solve(a, m, b);
}
@Test
@ -320,7 +320,7 @@ public class ConjugateGradientTest {
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, m, b, null, false);
final RealVector x = solver.solve(a, m, b);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
final double expected = ainv.getEntry(i, j);
@ -364,7 +364,7 @@ public class ConjugateGradientTest {
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, m, b, null, false);
final RealVector x = solver.solve(a, m, b);
final RealVector y = a.operate(x);
double rnorm = 0.;
for (int i = 0; i < n; i++) {
@ -411,8 +411,8 @@ public class ConjugateGradientTest {
for (int j = 0; j < 1; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector px = pcg.solve(a, m, b, null, false);
final RealVector x = cg.solve(a, b, null, false);
final RealVector px = pcg.solve(a, m, b);
final RealVector x = cg.solve(a, b);
final int npcg = pcg.getIterationManager().getIterations();
final int ncg = cg.getIterationManager().getIterations();
msg = String.format(pattern, npcg, ncg);
@ -465,7 +465,7 @@ public class ConjugateGradientTest {
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
solver.solve(a, b, null, false);
solver.solve(a, b);
String msg = String.format("column %d (initialization)", j);
Assert.assertEquals(msg, 1, count[0]);
msg = String.format("column %d (iterations started)", j);