Modifications to the ConjugateGradient class and unit tests
- altered the way iterations are counted: Incrementor is incremented prior to any modification to the current state, so that the solver is in a consistent state (accessible residual corresponds to the last estimate of the solution), even in case of MaxCountExceededException occuring. - modified some tests which were not testing anything. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1179488 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
8fcbe82ab6
commit
61018c7997
|
@ -174,13 +174,15 @@ public class ConjugateGradient
|
|||
manager.resetIterationCount();
|
||||
final double r2max = delta * delta * b.dotProduct(b);
|
||||
|
||||
// Initialization phase counts as one iteration.
|
||||
manager.incrementIterationCount();
|
||||
// p and x are constructed as copies of x0, since presumably, the type
|
||||
// of x is optimized for the calculation of the matrix-vector product
|
||||
// A.x.
|
||||
final RealVector x = x0;
|
||||
final RealVector p = x.copy();
|
||||
RealVector q = a.operate(p);
|
||||
manager.incrementIterationCount();
|
||||
|
||||
final RealVector r = b.combine(1, -1, q);
|
||||
double r2 = r.dotProduct(r);
|
||||
RealVector z;
|
||||
|
@ -213,6 +215,7 @@ public class ConjugateGradient
|
|||
}
|
||||
double rhoPrev = 0.;
|
||||
while (true) {
|
||||
manager.incrementIterationCount();
|
||||
manager.fireIterationStartedEvent(event);
|
||||
if (m != null) {
|
||||
z = m.solve(r);
|
||||
|
@ -226,13 +229,12 @@ public class ConjugateGradient
|
|||
context.setValue(VECTOR, r);
|
||||
throw e;
|
||||
}
|
||||
if (manager.getIterations() == 1) {
|
||||
if (manager.getIterations() == 2) {
|
||||
p.setSubVector(0, z);
|
||||
} else {
|
||||
p.combineToSelf(rhoNext / rhoPrev, 1., z);
|
||||
}
|
||||
q = a.operate(p);
|
||||
manager.incrementIterationCount();
|
||||
final double pq = p.dotProduct(q);
|
||||
if (check && (pq <= 0.)) {
|
||||
final NonPositiveDefiniteOperatorException e;
|
||||
|
|
|
@ -153,7 +153,7 @@ public class ConjugateGradientTest {
|
|||
* due to the loss of orthogonality of the successive search directions.
|
||||
* Therefore, in the present test, the number of iterations is limited.
|
||||
*/
|
||||
@Test(expected = MaxCountExceededException.class)
|
||||
@Test
|
||||
public void testUnpreconditionedResidual() {
|
||||
final int n = 10;
|
||||
final int maxIterations = n;
|
||||
|
@ -161,10 +161,11 @@ public class ConjugateGradientTest {
|
|||
final ConjugateGradient solver;
|
||||
solver = new ConjugateGradient(maxIterations, 1E-15, true);
|
||||
final RealVector r = new ArrayRealVector(n);
|
||||
final RealVector x = new ArrayRealVector(n);
|
||||
final IterationListener listener = new IterationListener() {
|
||||
|
||||
public void terminationPerformed(final IterationEvent e) {
|
||||
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
public void iterationStarted(final IterationEvent e) {
|
||||
|
@ -172,7 +173,10 @@ public class ConjugateGradientTest {
|
|||
}
|
||||
|
||||
public void iterationPerformed(final IterationEvent e) {
|
||||
// Do nothing
|
||||
RealVector v = ((ProvidesResidual)e).getResidual();
|
||||
r.setSubVector(0, v);
|
||||
v = ((IterativeLinearSolverEvent) e).getSolution();
|
||||
x.setSubVector(0, v);
|
||||
}
|
||||
|
||||
public void initializationPerformed(final IterationEvent e) {
|
||||
|
@ -180,22 +184,29 @@ public class ConjugateGradientTest {
|
|||
}
|
||||
};
|
||||
solver.getIterationManager().addIterationListener(listener);
|
||||
|
||||
final RealVector b = new ArrayRealVector(n);
|
||||
for (int j = 0; j < n; j++) {
|
||||
b.set(0.);
|
||||
b.setEntry(j, 1.);
|
||||
|
||||
final RealVector x = solver.solve(a, b);
|
||||
final RealVector y = a.operate(x);
|
||||
for (int i = 0; i < n; i++) {
|
||||
final double actual = b.getEntry(i) - y.getEntry(i);
|
||||
final double expected = r.getEntry(i);
|
||||
final double delta = 1E-6 * Math.abs(expected);
|
||||
final String msg = String
|
||||
.format("column %d, residual %d", i, j);
|
||||
Assert.assertEquals(msg, expected, actual, delta);
|
||||
boolean caught = false;
|
||||
try {
|
||||
solver.solve(a, b);
|
||||
} catch (MaxCountExceededException e) {
|
||||
caught = true;
|
||||
final RealVector y = a.operate(x);
|
||||
for (int i = 0; i < n; i++) {
|
||||
final double actual = b.getEntry(i) - y.getEntry(i);
|
||||
final double expected = r.getEntry(i);
|
||||
final double delta = 1E-6 * Math.abs(expected);
|
||||
final String msg = String
|
||||
.format("column %d, residual %d", i, j);
|
||||
Assert.assertEquals(msg, expected, actual, delta);
|
||||
}
|
||||
}
|
||||
Assert
|
||||
.assertTrue("MaxCountExceededException should have been caught",
|
||||
caught);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -331,7 +342,7 @@ public class ConjugateGradientTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test(expected = MaxCountExceededException.class)
|
||||
@Test
|
||||
public void testPreconditionedResidual() {
|
||||
final int n = 10;
|
||||
final int maxIterations = n;
|
||||
|
@ -340,10 +351,11 @@ public class ConjugateGradientTest {
|
|||
final ConjugateGradient solver;
|
||||
solver = new ConjugateGradient(maxIterations, 1E-15, true);
|
||||
final RealVector r = new ArrayRealVector(n);
|
||||
final RealVector x = new ArrayRealVector(n);
|
||||
final IterationListener listener = new IterationListener() {
|
||||
|
||||
public void terminationPerformed(final IterationEvent e) {
|
||||
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
public void iterationStarted(final IterationEvent e) {
|
||||
|
@ -351,7 +363,10 @@ public class ConjugateGradientTest {
|
|||
}
|
||||
|
||||
public void iterationPerformed(final IterationEvent e) {
|
||||
// Do nothing
|
||||
RealVector v = ((ProvidesResidual)e).getResidual();
|
||||
r.setSubVector(0, v);
|
||||
v = ((IterativeLinearSolverEvent) e).getSolution();
|
||||
x.setSubVector(0, v);
|
||||
}
|
||||
|
||||
public void initializationPerformed(final IterationEvent e) {
|
||||
|
@ -364,20 +379,25 @@ public class ConjugateGradientTest {
|
|||
for (int j = 0; j < n; j++) {
|
||||
b.set(0.);
|
||||
b.setEntry(j, 1.);
|
||||
final RealVector x = solver.solve(a, m, b);
|
||||
final RealVector y = a.operate(x);
|
||||
double rnorm = 0.;
|
||||
for (int i = 0; i < n; i++) {
|
||||
final double actual = b.getEntry(i) - y.getEntry(i);
|
||||
final double expected = r.getEntry(i);
|
||||
final double delta = 1E-6 * Math.abs(expected);
|
||||
final String msg = String
|
||||
.format("column %d, residual %d", i, j);
|
||||
Assert.assertEquals(msg, expected, actual, delta);
|
||||
|
||||
boolean caught = false;
|
||||
try {
|
||||
solver.solve(a, m, b);
|
||||
} catch (MaxCountExceededException e) {
|
||||
caught = true;
|
||||
final RealVector y = a.operate(x);
|
||||
for (int i = 0; i < n; i++) {
|
||||
final double actual = b.getEntry(i) - y.getEntry(i);
|
||||
final double expected = r.getEntry(i);
|
||||
final double delta = 1E-6 * Math.abs(expected);
|
||||
final String msg = String
|
||||
.format("column %d, residual %d", i, j);
|
||||
Assert.assertEquals(msg, expected, actual, delta);
|
||||
}
|
||||
}
|
||||
rnorm = r.getNorm();
|
||||
Assert.assertEquals("norm of residual", rnorm, r.getNorm(),
|
||||
1E-6 * Math.abs(rnorm));
|
||||
Assert
|
||||
.assertTrue("MaxCountExceededException should have been caught",
|
||||
caught);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue