Added support for iterative linear solvers (Conjugate Gradient only for now). See MATH-581.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1175404 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2011-09-25 14:48:18 +00:00
parent c60827b0b7
commit 39eea3de3f
9 changed files with 1391 additions and 0 deletions

View File

@ -0,0 +1,265 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.exception.MaxCountExceededException;
import org.apache.commons.math.exception.NullArgumentException;
import org.apache.commons.math.exception.util.ExceptionContext;
import org.apache.commons.math.util.IterationManager;
/**
* <p>
* This is an implementation of the conjugate gradient method for
* {@link RealLinearOperator}. It follows closely the template by <a
* href="#BARR1994">Barrett et al. (1994)</a> (figure 2.5). The linear system at
* hand is A &middot; x = b, and the residual is r = b - A &middot; x.
* </p>
* <h3><a id="stopcrit">Default stopping criterion</a></h3>
* <p>
* A default stopping criterion is implemented. The iterations stop when || r ||
* &le; &delta; || b ||, where b is the right-hand side vector, r the current
* estimate of the residual, and &delta; a user-specified tolerance. It should
* be noted that r is the so-called <em>updated</em> residual, which might
* differ from the true residual due to rounding-off errors (see e.g. <a
* href="#STRA2002">Strakos and Tichy, 2002</a>).
* </p>
* <h3>Iteration count</h3>
* <p>
* In the present context, an iteration should be understood as one evaluation
* of the matrix-vector product A &middot; x. The initialization phase therefore
* counts as one iteration.
* </p>
* <h3><a id="context">Exception context</a></h3>
* <p>
* Besides standard {@link DimensionMismatchException}, this class might throw
* {@link NonPositiveDefiniteLinearOperatorException} if the linear operator or
* the preconditioner are not positive definite. In this case, the
* {@link ExceptionContext} provides some more information
* <ul>
* <li>key {@code "operator"} points to the offending linear operator, say L,</li>
* <li>key {@code "vector"} points to the offending vector, say x, such that
* x<sup>T</sup> &middot; L &middot; x < 0.</li>
* </ul>
* </p>
* <h3>References</h3>
* <dl>
* <dt><a id="BARR1994">Barret et al. (1994)</a></dt>
* <dd>R. Barrett, M. Berry, T. F. Chan, J. Demmel, J. M. Donato, J. Dongarra,
* V. Eijkhout, R. Pozo, C. Romine and H. Van der Vorst,
* <em>Templates for the Solution of Linear Systems: Building Blocks for
* Iterative Methods</em>, SIAM</dd>
* <dt><a id="STRA2002">Strakos and Tichy (2002)
* <dt>
* <dd>Z. Strakos and P. Tichy, <a
* href="http://etna.mcs.kent.edu/vol.13.2002/pp56-80.dir/pp56-80.pdf">
* <em>On error estimation in the conjugate gradient method and why it works
* in finite precision computations</em></a>, Electronic Transactions on
* Numerical Analysis 13: 56-80, 2002</dd>
* </dl>
*
* @version $Id$
* @since 3.0
*/
public class ConjugateGradient
extends PreconditionedIterativeLinearSolver {
/**
* The type of all events fired by this implementation of the Conjugate
* Gradient method.
*
* @version $Id$
*/
public abstract static class ConjugateGradientEvent
extends IterativeLinearSolverEvent
implements ProvidesResidual {
/** */
private static final long serialVersionUID = 6461730085343318121L;
/**
* Creates a new instance of this class.
*
* @param source The iterative algorithm on which the event initially
* occurred.
*/
public ConjugateGradientEvent(final Object source) {
super(source);
}
}
/** Key for the <a href="#context">exception context</a>. */
public static final String OPERATOR = "operator";
/** Key for the <a href="#context">exception context</a>. */
public static final String VECTOR = "vector";
/**
* {@code true} if positive-definiteness of matrix and preconditioner should
* be checked.
*/
private boolean check;
/** The value of &delta;, for the default stopping criterion. */
private final double delta;
/**
* Creates a new instance of this class, with <a href="#stopcrit">default
* stopping criterion</a>.
*
* @param maxIterations Maximum number of iterations.
* @param delta &delta; parameter for the default stopping criterion.
* @param check {@code true} if positive definiteness of both matrix and
* preconditioner should be checked.
*/
public ConjugateGradient(final int maxIterations, final double delta,
final boolean check) {
super(maxIterations);
this.delta = delta;
this.check = check;
}
/**
* Creates a new instance of this class, with <a href="#stopcrit">default
* stopping criterion</a> and custom iteration manager.
*
* @param manager Custom iteration manager.
* @param delta &delta; parameter for the default stopping criterion.
* @param check {@code true} if positive definiteness of both matrix and
* preconditioner should be checked.
*/
public ConjugateGradient(final IterationManager manager,
final double delta, final boolean check) {
super(manager);
this.delta = delta;
this.check = check;
}
/**
* Returns {@code true} if positive-definiteness should be checked for both
* matrix and preconditioner.
*
* @return {@code true} if the tests are to be performed.
*/
public final boolean getCheck() {
return check;
}
/** {@inheritDoc} */
@Override
public RealVector solve(final RealLinearOperator a,
final InvertibleRealLinearOperator m,
final RealVector b, final RealVector x0,
final boolean inPlace)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
checkParameters(a, m, b, x0, inPlace);
final IterationManager manager = getIterationManager();
// Initialization of default stopping criterion
manager.resetIterationCount();
final double r2max = delta * delta * b.dotProduct(b);
// p and x are constructed as copies of x0, since presumably, the type
// of x is optimized for the calculation of the matrix-vector product
// A.x.
final RealVector x;
if (inPlace) {
x = x0;
} else {
if (x0 != null) {
x = x0.copy();
} else {
x = new ArrayRealVector(a.getColumnDimension());
}
}
final RealVector p = x.copy();
RealVector q = a.operate(p);
manager.incrementIterationCount();
final RealVector r = b.combine(1, -1, q);
double r2 = r.dotProduct(r);
RealVector z;
if (m == null) {
z = r;
} else {
z = null;
}
final IterativeLinearSolverEvent event;
event = new ConjugateGradientEvent(this) {
public RealVector getResidual() {
return ArrayRealVector.unmodifiableRealVector(r);
}
@Override
public RealVector getRightHandSideVector() {
return ArrayRealVector.unmodifiableRealVector(b);
}
@Override
public RealVector getSolution() {
return ArrayRealVector.unmodifiableRealVector(x);
}
};
manager.fireInitializationEvent(event);
if (r2 <= r2max) {
manager.fireTerminationEvent(event);
return x;
}
double rhoPrev = 0.;
while (true) {
manager.fireIterationStartedEvent(event);
if (m != null) {
z = m.solve(r);
}
final double rhoNext = r.dotProduct(z);
if (check && (rhoNext <= 0.)) {
final NonPositiveDefiniteLinearOperatorException e;
e = new NonPositiveDefiniteLinearOperatorException();
final ExceptionContext context = e.getContext();
context.setValue(OPERATOR, m);
context.setValue(VECTOR, r);
throw e;
}
if (manager.getIterations() == 1) {
p.setSubVector(0, z);
} else {
p.combineToSelf(rhoNext / rhoPrev, 1., z);
}
q = a.operate(p);
manager.incrementIterationCount();
final double pq = p.dotProduct(q);
if (check && (pq <= 0.)) {
final NonPositiveDefiniteLinearOperatorException e;
e = new NonPositiveDefiniteLinearOperatorException();
final ExceptionContext context = e.getContext();
context.setValue(OPERATOR, a);
context.setValue(VECTOR, p);
throw e;
}
final double alpha = rhoNext / pq;
x.combineToSelf(1., alpha, p);
r.combineToSelf(1., -alpha, q);
rhoPrev = rhoNext;
r2 = r.dotProduct(r);
manager.fireIterationPerformedEvent(event);
if (r2 <= r2max) {
manager.fireTerminationEvent(event);
return x;
}
}
}
}

View File

@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.exception.MaxCountExceededException;
import org.apache.commons.math.exception.NullArgumentException;
import org.apache.commons.math.util.IterationManager;
import org.apache.commons.math.util.MathUtils;
/**
* This abstract class defines an iterative solver for the linear system A
* &middot; x = b. In what follows, the <em>residual</em> r is defined as r = b
* - A &middot; x, where A is the linear operator of the linear system, b is the
* right-hand side vector, and x the current estimate of the solution.
*
* @version $Id$
* @since 3.0
*/
public abstract class IterativeLinearSolver {
/** The object in charge of managing the iterations. */
private final IterationManager manager;
/**
* Creates a new instance of this class, with default iteration manager.
*
* @param maxIterations Maximum number of iterations.
*/
public IterativeLinearSolver(final int maxIterations) {
this.manager = new IterationManager(maxIterations);
}
/**
* Creates a new instance of this class, with custom iteration manager.
*
* @param manager Custom iteration manager.
*/
public IterativeLinearSolver(final IterationManager manager) {
MathUtils.checkNotNull(manager);
this.manager = manager;
}
/**
* Performs all dimension checks on the parameters of
* {@link #solve(RealLinearOperator, RealVector, RealVector, boolean)}, and
* throws an exception if one of the checks fails.
*
* @param a Linear operator A of the system.
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} is not square.
* @throws DimensionMismatchException if {@code b} or {@code x0} have
* dimensions inconsistent with {@code a}.
*/
protected static void checkParameters(final RealLinearOperator a,
final RealVector b,
final RealVector x0,
final boolean inPlace)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException {
MathUtils.checkNotNull(a);
MathUtils.checkNotNull(b);
if (a.getRowDimension() != a.getColumnDimension()) {
throw new NonSquareLinearOperatorException(a.getRowDimension(),
a.getColumnDimension());
}
if (b.getDimension() != a.getRowDimension()) {
throw new DimensionMismatchException(b.getDimension(),
a.getRowDimension());
}
if (inPlace) {
MathUtils.checkNotNull(x0);
if (x0.getDimension() != a.getColumnDimension()) {
throw new DimensionMismatchException(x0.getDimension(),
a.getColumnDimension());
}
}
}
/**
* Returns the {@link IterationManager} attached to this solver.
*
* @return the manager.
*/
public IterationManager getIterationManager() {
return manager;
}
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b. If no initial estimate of the solution is provided, (0, &hellip;, 0)
* is assumed.
*
* @param a Linear operator A of the system.
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @return A reference to {@code x0} (shallow copy) if {@code inPlace} was
* set to {@code true}. Otherwise, a new vector containing the
* solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} is not square.
* @throws DimensionMismatchException if {@code b} or {@code x0} have
* dimensions inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has
* been set at construction.
*/
public abstract RealVector solve(RealLinearOperator a, RealVector b,
RealVector x0, boolean inPlace)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException;
}

View File

@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.util.IterationEvent;
/**
* This is the base class for all events occuring during the iterations of a
* {@link IterativeLinearSolver}.
*
* @version $Id$
* @since 3.0
*/
public abstract class IterativeLinearSolverEvent
extends IterationEvent {
/** */
private static final long serialVersionUID = 283291016904748030L;
/**
* Creates a new instance of this class.
*
* @param source The iterative algorithm on which the event initially
* occurred.
*/
public IterativeLinearSolverEvent(final Object source) {
super(source);
}
/**
* Returns the current right-hand side of the linear system to be solved.
* This method should return an unmodifiable view, or a deep copy of the
* actual right-hand side, in order not to compromise subsequent iterations
* of the source {@link IterativeLinearSolver}.
*
* @return The right-hand side vector, b.
*/
public abstract RealVector getRightHandSideVector();
/**
* Returns the current estimate of the solution to the linear system to be
* solved. This method should return an unmodifiable view, or a deep copy of
* the actual current solution, in order not to compromise subsequent
* iterations of the source {@link IterativeLinearSolver}.
*
* @return The solution, x.
*/
public abstract RealVector getSolution();
}

View File

@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
/**
* This class implements the standard Jacobi (diagonal) preconditioner.
*
* @version $Id$
* @since 3.0
*/
public class JacobiPreconditioner
extends InvertibleRealLinearOperator {
/** The diagonal coefficients of the preconditioner. */
private final ArrayRealVector diag;
/**
* Creates a new instance of this class.
*
* @param diag Diagonal coefficients of the preconditioner.
* @param deep {@code true} if a deep copy of the above array should be
* performed.
*/
public JacobiPreconditioner(final double[] diag, final boolean deep) {
this.diag = new ArrayRealVector(diag, deep);
}
/**
* Creates a new instance of this class. This method extracts the diagonal
* coefficients of the specified linear operator. If {@code a} does not
* extend {@link AbstractRealMatrix}, then the coefficients of the
* underlying matrix are not accessible, coefficient extraction is made by
* matrix-vector products with the basis vectors (and might therefore take
* some time). With matrices, direct entry access is carried out.
*
* @param a Linear operator for which the preconditioner should be built.
* @return Preconditioner made of the diagonal coefficients of the specified
* linear operator.
* @throws NonSquareLinearOperatorException if {@code a} is not square.
*/
public static JacobiPreconditioner create(final RealLinearOperator a)
throws NonSquareLinearOperatorException {
final int n = a.getColumnDimension();
if (a.getRowDimension() != n) {
throw new NonSquareLinearOperatorException(a.getRowDimension(), n);
}
final double[] diag = new double[n];
if (a instanceof AbstractRealMatrix) {
final AbstractRealMatrix m = (AbstractRealMatrix) a;
for (int i = 0; i < n; i++) {
diag[i] = m.getEntry(i, i);
}
} else {
final ArrayRealVector x = new ArrayRealVector(n);
for (int i = 0; i < n; i++) {
x.set(0.);
x.setEntry(i, 1.);
diag[i] = a.operate(x).getEntry(i);
}
}
return new JacobiPreconditioner(diag, false);
}
/** {@inheritDoc} */
@Override
public int getColumnDimension() {
return diag.getDimension();
}
/** {@inheritDoc} */
@Override
public int getRowDimension() {
return diag.getDimension();
}
/** {@inheritDoc} */
@Override
public RealVector operate(final RealVector x) {
// Dimension check is carried out by ebeMultiply
return x.ebeMultiply(diag);
}
/** {@inheritDoc} */
@Override
public RealVector solve(final RealVector b) {
// Dimension check is carried out by ebeDivide
return b.ebeDivide(diag);
}
}

View File

@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.exception.MaxCountExceededException;
import org.apache.commons.math.exception.NullArgumentException;
import org.apache.commons.math.util.IterationManager;
/**
* This abstract class defines preconditioned iterative solvers. When A is
* ill-conditioned, instead of solving system A &middot; x = b directly, it is
* preferable to solve M<sup>-1</sup> &middot; A &middot; x = M<sup>-1</sup>
* &middot; b, where M approximates in some way A, while remaining comparatively
* easier to invert. M (not M<sup>-1</sup>!) is called the
* <em>preconditionner</em>.
*
* @version $Id$
* @since 3.0
*/
public abstract class PreconditionedIterativeLinearSolver
extends IterativeLinearSolver {
/**
* Creates a new instance of this class, with default iteration manager.
*
* @param maxIterations Maximum number of iterations.
*/
public PreconditionedIterativeLinearSolver(final int maxIterations) {
super(maxIterations);
}
/**
* Creates a new instance of this class, with custom iteration manager.
*
* @param manager Custom iteration manager.
*/
public PreconditionedIterativeLinearSolver(final IterationManager manager) {
super(manager);
}
/**
* Performs all dimension checks on the parameters of
* {@link #solve(RealLinearOperator, InvertibleRealLinearOperator, RealVector, RealVector, boolean)}
* , and throws an exception if one of the checks fails.
*
* @param a Linear operator A of the system.
* @param m Preconditioner (can be {@code null}).
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} or {@code m} is not
* square.
* @throws DimensionMismatchException if {@code m}, {@code b} or {@code x}
* have dimensions inconsistent with {@code a}.
*/
protected static void checkParameters(final RealLinearOperator a,
final InvertibleRealLinearOperator m,
final RealVector b,
final RealVector x0,
final boolean inPlace)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException {
checkParameters(a, b, x0, inPlace);
if (m != null) {
if (m.getColumnDimension() != m.getRowDimension()) {
throw new NonSquareLinearOperatorException(
m.getColumnDimension(),
m.getRowDimension());
}
if (m.getRowDimension() != a.getRowDimension()) {
throw new DimensionMismatchException(m.getRowDimension(),
a.getRowDimension());
}
}
}
/**
* Returns an estimate of the solution to the linear system A &middot; x =
* b. If no initial estimate of the solution is provided, (0, &hellip;, 0)
* is assumed.
*
* @param a Linear operator A of the system.
* @param m Preconditioner (can be {@code null}).
* @param b Right-hand side vector.
* @param x0 Initial guess of the solution (can be {@code null} if
* {@code inPlace} is set to {@code false}).
* @param inPlace {@code true} if the initial guess is to be updated with
* the current estimate of the solution.
* @return A reference to {@code x0} (shallow copy) if {@code update} was
* set to {@code true}. Otherwise, a new vector containing the
* solution.
* @throws NullArgumentException if one of the parameters is {@code null}.
* @throws NonSquareLinearOperatorException if {@code a} or {@code m} is not
* square.
* @throws DimensionMismatchException if {@code m}, {@code b} or {@code x}
* have dimensions inconsistent with {@code a}.
* @throws MaxCountExceededException at exhaustion of the iteration count,
* unless a custom {@link MaxCountExceededCallback callback} has
* been set at construction.
*/
public abstract RealVector solve(RealLinearOperator a,
InvertibleRealLinearOperator m,
RealVector b, RealVector x0,
final boolean inPlace)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException;
/** {@inheritDoc} */
@Override
public RealVector solve(final RealLinearOperator a, final RealVector b,
final RealVector x, final boolean inPlace)
throws NullArgumentException, NonSquareLinearOperatorException,
DimensionMismatchException, MaxCountExceededException {
checkParameters(a, b, x, inPlace);
return solve(a, null, b, x, inPlace);
}
}

View File

@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
/**
* This interface provides access to the current value of the residual of an
* {@link IterativeLinearSolver}. It might be implemented by
* {@link IterativeLinearSolverEvent}, for example.
*
* @version $Id$
* @since 3.0
*/
public interface ProvidesResidual {
/**
* Returns the current value of the residual. This should be an
* unmodifiable view or a deep copy of the residual, in order not to
* compromise the subsequent iterations.
*
* @return the current value of the residual.
*/
RealVector getResidual();
}

View File

@ -0,0 +1,481 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.exception.MaxCountExceededException;
import org.apache.commons.math.util.IterationEvent;
import org.apache.commons.math.util.IterationListener;
import org.junit.Assert;
import org.junit.Test;
public class ConjugateGradientTest {
@Test(expected = NonSquareLinearOperatorException.class)
public void testNonSquareOperator() {
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 3);
final IterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0., false);
final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
solver.solve(a, b, x, false);
}
@Test(expected = DimensionMismatchException.class)
public void testDimensionMismatchRightHandSide() {
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
final IterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0., false);
final ArrayRealVector b = new ArrayRealVector(2);
final ArrayRealVector x = new ArrayRealVector(3);
solver.solve(a, b, x, false);
}
@Test(expected = DimensionMismatchException.class)
public void testDimensionMismatchSolution() {
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
final IterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0., false);
final ArrayRealVector b = new ArrayRealVector(3);
final ArrayRealVector x = new ArrayRealVector(2);
solver.solve(a, b, x, false);
}
@Test(expected = NonPositiveDefiniteLinearOperatorException.class)
public void testNonPositiveDefiniteLinearOperator() {
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
a.setEntry(0, 0, -1.);
a.setEntry(0, 1, 2.);
a.setEntry(1, 0, 3.);
a.setEntry(1, 1, 4.);
final IterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0., true);
final ArrayRealVector b = new ArrayRealVector(2);
b.setEntry(0, -1.);
b.setEntry(1, -1.);
final ArrayRealVector x = new ArrayRealVector(2);
solver.solve(a, b, x, false);
}
@Test
public void testUnpreconditionedSolution() {
final int n = 5;
final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n);
final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
final IterativeLinearSolver solver;
solver = new ConjugateGradient(maxIterations, 1E-10, true);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, b, null, false);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
final double expected = ainv.getEntry(i, j);
final double delta = 1E-10 * Math.abs(expected);
final String msg = String.format("entry[%d][%d]", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
}
}
@Test
public void testUnpreconditionedInPlaceSolutionWithInitialGuess() {
final int n = 5;
final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n);
final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
final IterativeLinearSolver solver;
solver = new ConjugateGradient(maxIterations, 1E-10, true);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x0 = new ArrayRealVector(n);
x0.set(1.);
final RealVector x = solver.solve(a, b, x0, true);
Assert.assertSame("x should be a reference to x0", x0, x);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
final double expected = ainv.getEntry(i, j);
final double delta = 1E-10 * Math.abs(expected);
final String msg = String.format("entry[%d][%d)", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
}
}
@Test
public void testUnpreconditionedSolutionWithInitialGuess() {
final int n = 5;
final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n);
final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
final IterativeLinearSolver solver;
solver = new ConjugateGradient(maxIterations, 1E-10, true);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x0 = new ArrayRealVector(n);
x0.set(1.);
final RealVector x = solver.solve(a, b, x0, false);
Assert.assertNotSame("x should not be a reference to x0", x0, x);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
final double expected = ainv.getEntry(i, j);
final double delta = 1E-10 * Math.abs(expected);
final String msg = String.format("entry[%d][%d]", i, j);
Assert.assertEquals(msg, expected, actual, delta);
Assert.assertEquals(msg, x0.getEntry(i), 1., Math.ulp(1.));
}
}
}
/**
* Check whether the estimate of the (updated) residual corresponds to the
* exact residual. This fails to be true for a large number of iterations,
* due to the loss of orthogonality of the successive search directions.
* Therefore, in the present test, the number of iterations is limited.
*/
@Test(expected = MaxCountExceededException.class)
public void testUnpreconditionedResidual() {
final int n = 10;
final int maxIterations = n;
final RealLinearOperator a = new HilbertMatrix(n);
final ConjugateGradient solver;
solver = new ConjugateGradient(maxIterations, 1E-15, true);
final RealVector r = new ArrayRealVector(n);
final IterationListener listener = new IterationListener() {
public void terminationPerformed(final IterationEvent e) {
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
}
public void iterationStarted(final IterationEvent e) {
// Do nothing
}
public void iterationPerformed(final IterationEvent e) {
// Do nothing
}
public void initializationPerformed(final IterationEvent e) {
// Do nothing
}
};
solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, b, null, false);
final RealVector y = a.operate(x);
for (int i = 0; i < n; i++) {
final double actual = b.getEntry(i) - y.getEntry(i);
final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
}
}
@Test(expected = NonSquareLinearOperatorException.class)
public void testNonSquarePreconditioner() {
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
final InvertibleRealLinearOperator m;
m = new InvertibleRealLinearOperator() {
@Override
public RealVector operate(final RealVector x) {
throw new UnsupportedOperationException();
}
@Override
public int getRowDimension() {
return 2;
}
@Override
public int getColumnDimension() {
return 3;
}
@Override
public RealVector solve(final RealVector b) {
throw new UnsupportedOperationException();
}
};
final PreconditionedIterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0d, false);
final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
solver.solve(a, m, b, null, false);
}
@Test(expected = DimensionMismatchException.class)
public void testMismatchedOperatorDimensions() {
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
final InvertibleRealLinearOperator m;
m = new InvertibleRealLinearOperator() {
@Override
public RealVector operate(final RealVector x) {
throw new UnsupportedOperationException();
}
@Override
public int getRowDimension() {
return 3;
}
@Override
public int getColumnDimension() {
return 3;
}
@Override
public RealVector solve(final RealVector b) {
throw new UnsupportedOperationException();
}
};
final PreconditionedIterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0d, false);
final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
solver.solve(a, m, b, null, false);
}
@Test(expected = NonPositiveDefiniteLinearOperatorException.class)
public void testNonPositiveDefinitePreconditioner() {
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
a.setEntry(0, 0, 1d);
a.setEntry(0, 1, 2d);
a.setEntry(1, 0, 3d);
a.setEntry(1, 1, 4d);
final InvertibleRealLinearOperator m;
m = new InvertibleRealLinearOperator() {
@Override
public RealVector operate(final RealVector x) {
final ArrayRealVector y = new ArrayRealVector(2);
y.setEntry(0, -x.getEntry(0));
y.setEntry(1, x.getEntry(1));
return y;
}
@Override
public int getRowDimension() {
return 2;
}
@Override
public int getColumnDimension() {
return 2;
}
@Override
public RealVector solve(final RealVector b) {
final ArrayRealVector x = new ArrayRealVector(2);
x.setEntry(0, -b.getEntry(0));
x.setEntry(1, b.getEntry(1));
return x;
}
};
final PreconditionedIterativeLinearSolver solver;
solver = new ConjugateGradient(10, 0d, true);
final ArrayRealVector b = new ArrayRealVector(2);
b.setEntry(0, -1d);
b.setEntry(1, -1d);
solver.solve(a, m, b, null, false);
}
@Test
public void testPreconditionedSolution() {
final int n = 8;
final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n);
final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
final InvertibleRealLinearOperator m = JacobiPreconditioner.create(a);
final PreconditionedIterativeLinearSolver solver;
solver = new ConjugateGradient(maxIterations, 1E-15, true);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, m, b, null, false);
for (int i = 0; i < n; i++) {
final double actual = x.getEntry(i);
final double expected = ainv.getEntry(i, j);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String.format("coefficient (%d, %d)", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
}
}
@Test(expected = MaxCountExceededException.class)
public void testPreconditionedResidual() {
final int n = 10;
final int maxIterations = n;
final RealLinearOperator a = new HilbertMatrix(n);
final InvertibleRealLinearOperator m = JacobiPreconditioner.create(a);
final ConjugateGradient solver;
solver = new ConjugateGradient(maxIterations, 1E-15, true);
final RealVector r = new ArrayRealVector(n);
final IterationListener listener = new IterationListener() {
public void terminationPerformed(final IterationEvent e) {
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
}
public void iterationStarted(final IterationEvent e) {
// Do nothing
}
public void iterationPerformed(final IterationEvent e) {
// Do nothing
}
public void initializationPerformed(final IterationEvent e) {
// Do nothing
}
};
solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, m, b, null, false);
final RealVector y = a.operate(x);
double rnorm = 0.;
for (int i = 0; i < n; i++) {
final double actual = b.getEntry(i) - y.getEntry(i);
final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
rnorm = r.getNorm();
Assert.assertEquals("norm of residual", rnorm, r.getNorm(),
1E-6 * Math.abs(rnorm));
}
}
@Test
public void testPreconditionedSolution2() {
final int n = 100;
final int maxIterations = 100000;
final Array2DRowRealMatrix a = new Array2DRowRealMatrix(n, n);
double daux = 1.;
for (int i = 0; i < n; i++) {
a.setEntry(i, i, daux);
daux *= 1.2;
for (int j = i + 1; j < n; j++) {
if (i == j) {
} else {
final double value = 1.0;
a.setEntry(i, j, value);
a.setEntry(j, i, value);
}
}
}
final InvertibleRealLinearOperator m = JacobiPreconditioner.create(a);
final PreconditionedIterativeLinearSolver pcg;
final IterativeLinearSolver cg;
pcg = new ConjugateGradient(maxIterations, 1E-6, true);
cg = new ConjugateGradient(maxIterations, 1E-6, true);
final RealVector b = new ArrayRealVector(n);
final String pattern = "preconditioned gradient (%d iterations) should"
+ " have been faster than unpreconditioned (%d iterations)";
String msg;
for (int j = 0; j < 1; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector px = pcg.solve(a, m, b, null, false);
final RealVector x = cg.solve(a, b, null, false);
final int npcg = pcg.getIterationManager().getIterations();
final int ncg = cg.getIterationManager().getIterations();
msg = String.format(pattern, npcg, ncg);
Assert.assertTrue(msg, npcg < ncg);
for (int i = 0; i < n; i++) {
msg = String.format("row %d, column %d", i, j);
final double expected = x.getEntry(i);
final double actual = px.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
Assert.assertEquals(msg, expected, actual, delta);
}
}
}
@Test
public void testEventManagement() {
final int n = 5;
final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n);
final IterativeLinearSolver solver;
final int[] count = new int[] {
0, 0, 0, 0
};
final IterationListener listener = new IterationListener() {
public void initializationPerformed(final IterationEvent e) {
count[0] = 1;
count[1] = 0;
count[2] = 0;
count[3] = 0;
}
public void iterationPerformed(final IterationEvent e) {
++count[2];
}
public void iterationStarted(IterationEvent e) {
++count[1];
}
public void terminationPerformed(final IterationEvent e) {
++count[3];
}
};
solver = new ConjugateGradient(maxIterations, 1E-10, true);
solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
solver.solve(a, b, null, false);
String msg = String.format("column %d (initialization)", j);
Assert.assertEquals(msg, 1, count[0]);
msg = String.format("column %d (iterations started)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[1]);
msg = String.format("column %d (iterations performed)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[2]);
msg = String.format("column %d (finalization)", j);
Assert.assertEquals(msg, 1, count[3]);
}
}
}

View File

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.exception.DimensionMismatchException;
/** This class implements Hilbert Matrices as {@link RealLinearOperator}. */
public class HilbertMatrix
extends RealLinearOperator {
/** The size of the matrix. */
private final int n;
/**
* Creates a new instance of this class.
*
* @param n Size of the matrix to be created..
*/
public HilbertMatrix(final int n) {
this.n = n;
}
/** {@inheritDoc} */
@Override
public int getColumnDimension() {
return n;
}
/** {@inheritDoc} */
@Override
public int getRowDimension() {
return n;
}
/** {@inheritDoc} */
@Override
public RealVector operate(final RealVector x) {
if (x.getDimension() != n) {
throw new DimensionMismatchException(x.getDimension(), n);
}
final double[] y = new double[n];
for (int i = 0; i < n; i++) {
double pos = 0.;
double neg = 0.;
for (int j = 0; j < n; j++) {
final double xj = x.getEntry(j);
final double coeff = 1. / (i + j + 1.);
// Positive and negative values are sorted out in order to limit
// catastrophic cancellations (do not forget that Hilbert
// matrices are *very* ill-conditioned!
if (xj > 0.) {
pos += coeff * xj;
} else {
neg += coeff * xj;
}
}
y[i] = pos + neg;
}
return new ArrayRealVector(y, false);
}
}

View File

@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math.linear;
import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.util.MathUtils;
/**
* This class implements inverses of Hilbert Matrices as
* {@link RealLinearOperator}.
*/
public class InverseHilbertMatrix
extends RealLinearOperator {
/** The size of the matrix. */
private final int n;
/**
* Creates a new instance of this class.
*
* @param n Size of the matrix to be created.
*/
public InverseHilbertMatrix(final int n) {
this.n = n;
}
/** {@inheritDoc} */
@Override
public int getColumnDimension() {
return n;
}
/**
* Returns the {@code (i, j)} entry of the inverse Hilbert matrix. Exact
* arithmetic is used; in case of overflow, an exception is thrown.
*
* @param i Row index (starts at 0).
* @param j Column index (starts at 0).
* @return The coefficient of the inverse Hilbert matrix.
*/
public long getEntry(final int i, final int j) {
long val = i + j + 1;
long aux = MathUtils.binomialCoefficient(n + i, n - j - 1);
val = MathUtils.mulAndCheck(val, aux);
aux = MathUtils.binomialCoefficient(n + j, n - i - 1);
val = MathUtils.mulAndCheck(val, aux);
aux = MathUtils.binomialCoefficient(i + j, i);
val = MathUtils.mulAndCheck(val, aux);
val = MathUtils.mulAndCheck(val, aux);
return ((i + j) & 1) == 0 ? val : -val;
}
/** {@inheritDoc} */
@Override
public int getRowDimension() {
return n;
}
/** {@inheritDoc} */
@Override
public RealVector operate(final RealVector x) {
if (x.getDimension() != n) {
throw new DimensionMismatchException(x.getDimension(), n);
}
final double[] y = new double[n];
for (int i = 0; i < n; i++) {
double pos = 0.;
double neg = 0.;
for (int j = 0; j < n; j++) {
final double xj = x.getEntry(j);
final long coeff = getEntry(i, j);
final double daux = coeff * xj;
// Positive and negative values are sorted out in order to limit
// catastrophic cancellations (do not forget that Hilbert
// matrices are *very* ill-conditioned!
if (daux > 0.) {
pos += daux;
} else {
neg += daux;
}
}
y[i] = pos + neg;
}
return new ArrayRealVector(y, false);
}
}