Modifications to the ConjugateGradient class and unit tests

- altered the way iterations are counted: Incrementor is incremented prior to any modification to the current state, so that the solver is in a consistent state (accessible residual corresponds to the last estimate of the solution), even in case of MaxCountExceededException occuring.
  - modified some tests which were not testing anything.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1179488 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2011-10-06 02:14:20 +00:00
parent 8fcbe82ab6
commit 61018c7997
2 changed files with 54 additions and 32 deletions

View File

@ -174,13 +174,15 @@ public class ConjugateGradient
manager.resetIterationCount(); manager.resetIterationCount();
final double r2max = delta * delta * b.dotProduct(b); final double r2max = delta * delta * b.dotProduct(b);
// Initialization phase counts as one iteration.
manager.incrementIterationCount();
// p and x are constructed as copies of x0, since presumably, the type // p and x are constructed as copies of x0, since presumably, the type
// of x is optimized for the calculation of the matrix-vector product // of x is optimized for the calculation of the matrix-vector product
// A.x. // A.x.
final RealVector x = x0; final RealVector x = x0;
final RealVector p = x.copy(); final RealVector p = x.copy();
RealVector q = a.operate(p); RealVector q = a.operate(p);
manager.incrementIterationCount();
final RealVector r = b.combine(1, -1, q); final RealVector r = b.combine(1, -1, q);
double r2 = r.dotProduct(r); double r2 = r.dotProduct(r);
RealVector z; RealVector z;
@ -213,6 +215,7 @@ public class ConjugateGradient
} }
double rhoPrev = 0.; double rhoPrev = 0.;
while (true) { while (true) {
manager.incrementIterationCount();
manager.fireIterationStartedEvent(event); manager.fireIterationStartedEvent(event);
if (m != null) { if (m != null) {
z = m.solve(r); z = m.solve(r);
@ -226,13 +229,12 @@ public class ConjugateGradient
context.setValue(VECTOR, r); context.setValue(VECTOR, r);
throw e; throw e;
} }
if (manager.getIterations() == 1) { if (manager.getIterations() == 2) {
p.setSubVector(0, z); p.setSubVector(0, z);
} else { } else {
p.combineToSelf(rhoNext / rhoPrev, 1., z); p.combineToSelf(rhoNext / rhoPrev, 1., z);
} }
q = a.operate(p); q = a.operate(p);
manager.incrementIterationCount();
final double pq = p.dotProduct(q); final double pq = p.dotProduct(q);
if (check && (pq <= 0.)) { if (check && (pq <= 0.)) {
final NonPositiveDefiniteOperatorException e; final NonPositiveDefiniteOperatorException e;

View File

@ -153,7 +153,7 @@ public class ConjugateGradientTest {
* due to the loss of orthogonality of the successive search directions. * due to the loss of orthogonality of the successive search directions.
* Therefore, in the present test, the number of iterations is limited. * Therefore, in the present test, the number of iterations is limited.
*/ */
@Test(expected = MaxCountExceededException.class) @Test
public void testUnpreconditionedResidual() { public void testUnpreconditionedResidual() {
final int n = 10; final int n = 10;
final int maxIterations = n; final int maxIterations = n;
@ -161,10 +161,11 @@ public class ConjugateGradientTest {
final ConjugateGradient solver; final ConjugateGradient solver;
solver = new ConjugateGradient(maxIterations, 1E-15, true); solver = new ConjugateGradient(maxIterations, 1E-15, true);
final RealVector r = new ArrayRealVector(n); final RealVector r = new ArrayRealVector(n);
final RealVector x = new ArrayRealVector(n);
final IterationListener listener = new IterationListener() { final IterationListener listener = new IterationListener() {
public void terminationPerformed(final IterationEvent e) { public void terminationPerformed(final IterationEvent e) {
r.setSubVector(0, ((ProvidesResidual) e).getResidual()); // Do nothing
} }
public void iterationStarted(final IterationEvent e) { public void iterationStarted(final IterationEvent e) {
@ -172,7 +173,10 @@ public class ConjugateGradientTest {
} }
public void iterationPerformed(final IterationEvent e) { public void iterationPerformed(final IterationEvent e) {
// Do nothing RealVector v = ((ProvidesResidual)e).getResidual();
r.setSubVector(0, v);
v = ((IterativeLinearSolverEvent) e).getSolution();
x.setSubVector(0, v);
} }
public void initializationPerformed(final IterationEvent e) { public void initializationPerformed(final IterationEvent e) {
@ -180,22 +184,29 @@ public class ConjugateGradientTest {
} }
}; };
solver.getIterationManager().addIterationListener(listener); solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n); final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) { for (int j = 0; j < n; j++) {
b.set(0.); b.set(0.);
b.setEntry(j, 1.); b.setEntry(j, 1.);
final RealVector x = solver.solve(a, b); boolean caught = false;
final RealVector y = a.operate(x); try {
for (int i = 0; i < n; i++) { solver.solve(a, b);
final double actual = b.getEntry(i) - y.getEntry(i); } catch (MaxCountExceededException e) {
final double expected = r.getEntry(i); caught = true;
final double delta = 1E-6 * Math.abs(expected); final RealVector y = a.operate(x);
final String msg = String for (int i = 0; i < n; i++) {
.format("column %d, residual %d", i, j); final double actual = b.getEntry(i) - y.getEntry(i);
Assert.assertEquals(msg, expected, actual, delta); final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
} }
Assert
.assertTrue("MaxCountExceededException should have been caught",
caught);
} }
} }
@ -331,7 +342,7 @@ public class ConjugateGradientTest {
} }
} }
@Test(expected = MaxCountExceededException.class) @Test
public void testPreconditionedResidual() { public void testPreconditionedResidual() {
final int n = 10; final int n = 10;
final int maxIterations = n; final int maxIterations = n;
@ -340,10 +351,11 @@ public class ConjugateGradientTest {
final ConjugateGradient solver; final ConjugateGradient solver;
solver = new ConjugateGradient(maxIterations, 1E-15, true); solver = new ConjugateGradient(maxIterations, 1E-15, true);
final RealVector r = new ArrayRealVector(n); final RealVector r = new ArrayRealVector(n);
final RealVector x = new ArrayRealVector(n);
final IterationListener listener = new IterationListener() { final IterationListener listener = new IterationListener() {
public void terminationPerformed(final IterationEvent e) { public void terminationPerformed(final IterationEvent e) {
r.setSubVector(0, ((ProvidesResidual) e).getResidual()); // Do nothing
} }
public void iterationStarted(final IterationEvent e) { public void iterationStarted(final IterationEvent e) {
@ -351,7 +363,10 @@ public class ConjugateGradientTest {
} }
public void iterationPerformed(final IterationEvent e) { public void iterationPerformed(final IterationEvent e) {
// Do nothing RealVector v = ((ProvidesResidual)e).getResidual();
r.setSubVector(0, v);
v = ((IterativeLinearSolverEvent) e).getSolution();
x.setSubVector(0, v);
} }
public void initializationPerformed(final IterationEvent e) { public void initializationPerformed(final IterationEvent e) {
@ -364,20 +379,25 @@ public class ConjugateGradientTest {
for (int j = 0; j < n; j++) { for (int j = 0; j < n; j++) {
b.set(0.); b.set(0.);
b.setEntry(j, 1.); b.setEntry(j, 1.);
final RealVector x = solver.solve(a, m, b);
final RealVector y = a.operate(x); boolean caught = false;
double rnorm = 0.; try {
for (int i = 0; i < n; i++) { solver.solve(a, m, b);
final double actual = b.getEntry(i) - y.getEntry(i); } catch (MaxCountExceededException e) {
final double expected = r.getEntry(i); caught = true;
final double delta = 1E-6 * Math.abs(expected); final RealVector y = a.operate(x);
final String msg = String for (int i = 0; i < n; i++) {
.format("column %d, residual %d", i, j); final double actual = b.getEntry(i) - y.getEntry(i);
Assert.assertEquals(msg, expected, actual, delta); final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
} }
rnorm = r.getNorm(); Assert
Assert.assertEquals("norm of residual", rnorm, r.getNorm(), .assertTrue("MaxCountExceededException should have been caught",
1E-6 * Math.abs(rnorm)); caught);
} }
} }