Modifications to the ConjugateGradient class and unit tests

- altered the way iterations are counted: Incrementor is incremented prior to any modification to the current state, so that the solver is in a consistent state (accessible residual corresponds to the last estimate of the solution), even in case of MaxCountExceededException occuring.
  - modified some tests which were not testing anything.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1179488 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2011-10-06 02:14:20 +00:00
parent 8fcbe82ab6
commit 61018c7997
2 changed files with 54 additions and 32 deletions

View File

@ -174,13 +174,15 @@ public class ConjugateGradient
manager.resetIterationCount();
final double r2max = delta * delta * b.dotProduct(b);
// Initialization phase counts as one iteration.
manager.incrementIterationCount();
// p and x are constructed as copies of x0, since presumably, the type
// of x is optimized for the calculation of the matrix-vector product
// A.x.
final RealVector x = x0;
final RealVector p = x.copy();
RealVector q = a.operate(p);
manager.incrementIterationCount();
final RealVector r = b.combine(1, -1, q);
double r2 = r.dotProduct(r);
RealVector z;
@ -213,6 +215,7 @@ public class ConjugateGradient
}
double rhoPrev = 0.;
while (true) {
manager.incrementIterationCount();
manager.fireIterationStartedEvent(event);
if (m != null) {
z = m.solve(r);
@ -226,13 +229,12 @@ public class ConjugateGradient
context.setValue(VECTOR, r);
throw e;
}
if (manager.getIterations() == 1) {
if (manager.getIterations() == 2) {
p.setSubVector(0, z);
} else {
p.combineToSelf(rhoNext / rhoPrev, 1., z);
}
q = a.operate(p);
manager.incrementIterationCount();
final double pq = p.dotProduct(q);
if (check && (pq <= 0.)) {
final NonPositiveDefiniteOperatorException e;

View File

@ -153,7 +153,7 @@ public class ConjugateGradientTest {
* due to the loss of orthogonality of the successive search directions.
* Therefore, in the present test, the number of iterations is limited.
*/
@Test(expected = MaxCountExceededException.class)
@Test
public void testUnpreconditionedResidual() {
final int n = 10;
final int maxIterations = n;
@ -161,10 +161,11 @@ public class ConjugateGradientTest {
final ConjugateGradient solver;
solver = new ConjugateGradient(maxIterations, 1E-15, true);
final RealVector r = new ArrayRealVector(n);
final RealVector x = new ArrayRealVector(n);
final IterationListener listener = new IterationListener() {
public void terminationPerformed(final IterationEvent e) {
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
// Do nothing
}
public void iterationStarted(final IterationEvent e) {
@ -172,7 +173,10 @@ public class ConjugateGradientTest {
}
public void iterationPerformed(final IterationEvent e) {
// Do nothing
RealVector v = ((ProvidesResidual)e).getResidual();
r.setSubVector(0, v);
v = ((IterativeLinearSolverEvent) e).getSolution();
x.setSubVector(0, v);
}
public void initializationPerformed(final IterationEvent e) {
@ -180,22 +184,29 @@ public class ConjugateGradientTest {
}
};
solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, b);
final RealVector y = a.operate(x);
for (int i = 0; i < n; i++) {
final double actual = b.getEntry(i) - y.getEntry(i);
final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
boolean caught = false;
try {
solver.solve(a, b);
} catch (MaxCountExceededException e) {
caught = true;
final RealVector y = a.operate(x);
for (int i = 0; i < n; i++) {
final double actual = b.getEntry(i) - y.getEntry(i);
final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
}
Assert
.assertTrue("MaxCountExceededException should have been caught",
caught);
}
}
@ -331,7 +342,7 @@ public class ConjugateGradientTest {
}
}
@Test(expected = MaxCountExceededException.class)
@Test
public void testPreconditionedResidual() {
final int n = 10;
final int maxIterations = n;
@ -340,10 +351,11 @@ public class ConjugateGradientTest {
final ConjugateGradient solver;
solver = new ConjugateGradient(maxIterations, 1E-15, true);
final RealVector r = new ArrayRealVector(n);
final RealVector x = new ArrayRealVector(n);
final IterationListener listener = new IterationListener() {
public void terminationPerformed(final IterationEvent e) {
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
// Do nothing
}
public void iterationStarted(final IterationEvent e) {
@ -351,7 +363,10 @@ public class ConjugateGradientTest {
}
public void iterationPerformed(final IterationEvent e) {
// Do nothing
RealVector v = ((ProvidesResidual)e).getResidual();
r.setSubVector(0, v);
v = ((IterativeLinearSolverEvent) e).getSolution();
x.setSubVector(0, v);
}
public void initializationPerformed(final IterationEvent e) {
@ -364,20 +379,25 @@ public class ConjugateGradientTest {
for (int j = 0; j < n; j++) {
b.set(0.);
b.setEntry(j, 1.);
final RealVector x = solver.solve(a, m, b);
final RealVector y = a.operate(x);
double rnorm = 0.;
for (int i = 0; i < n; i++) {
final double actual = b.getEntry(i) - y.getEntry(i);
final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
boolean caught = false;
try {
solver.solve(a, m, b);
} catch (MaxCountExceededException e) {
caught = true;
final RealVector y = a.operate(x);
for (int i = 0; i < n; i++) {
final double actual = b.getEntry(i) - y.getEntry(i);
final double expected = r.getEntry(i);
final double delta = 1E-6 * Math.abs(expected);
final String msg = String
.format("column %d, residual %d", i, j);
Assert.assertEquals(msg, expected, actual, delta);
}
}
rnorm = r.getNorm();
Assert.assertEquals("norm of residual", rnorm, r.getNorm(),
1E-6 * Math.abs(rnorm));
Assert
.assertTrue("MaxCountExceededException should have been caught",
caught);
}
}