diff --git a/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java b/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java index 55553a0df..6dfb4e63b 100644 --- a/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java +++ b/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java @@ -174,13 +174,15 @@ public class ConjugateGradient manager.resetIterationCount(); final double r2max = delta * delta * b.dotProduct(b); + // Initialization phase counts as one iteration. + manager.incrementIterationCount(); // p and x are constructed as copies of x0, since presumably, the type // of x is optimized for the calculation of the matrix-vector product // A.x. final RealVector x = x0; final RealVector p = x.copy(); RealVector q = a.operate(p); - manager.incrementIterationCount(); + final RealVector r = b.combine(1, -1, q); double r2 = r.dotProduct(r); RealVector z; @@ -213,6 +215,7 @@ public class ConjugateGradient } double rhoPrev = 0.; while (true) { + manager.incrementIterationCount(); manager.fireIterationStartedEvent(event); if (m != null) { z = m.solve(r); @@ -226,13 +229,12 @@ public class ConjugateGradient context.setValue(VECTOR, r); throw e; } - if (manager.getIterations() == 1) { + if (manager.getIterations() == 2) { p.setSubVector(0, z); } else { p.combineToSelf(rhoNext / rhoPrev, 1., z); } q = a.operate(p); - manager.incrementIterationCount(); final double pq = p.dotProduct(q); if (check && (pq <= 0.)) { final NonPositiveDefiniteOperatorException e; diff --git a/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java b/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java index 9e4845978..8f313a150 100644 --- a/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java +++ b/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java @@ -153,7 +153,7 @@ public class ConjugateGradientTest { * due to the loss of orthogonality of the successive search directions. * Therefore, in the present test, the number of iterations is limited. */ - @Test(expected = MaxCountExceededException.class) + @Test public void testUnpreconditionedResidual() { final int n = 10; final int maxIterations = n; @@ -161,10 +161,11 @@ public class ConjugateGradientTest { final ConjugateGradient solver; solver = new ConjugateGradient(maxIterations, 1E-15, true); final RealVector r = new ArrayRealVector(n); + final RealVector x = new ArrayRealVector(n); final IterationListener listener = new IterationListener() { public void terminationPerformed(final IterationEvent e) { - r.setSubVector(0, ((ProvidesResidual) e).getResidual()); + // Do nothing } public void iterationStarted(final IterationEvent e) { @@ -172,7 +173,10 @@ public class ConjugateGradientTest { } public void iterationPerformed(final IterationEvent e) { - // Do nothing + RealVector v = ((ProvidesResidual)e).getResidual(); + r.setSubVector(0, v); + v = ((IterativeLinearSolverEvent) e).getSolution(); + x.setSubVector(0, v); } public void initializationPerformed(final IterationEvent e) { @@ -180,22 +184,29 @@ public class ConjugateGradientTest { } }; solver.getIterationManager().addIterationListener(listener); - final RealVector b = new ArrayRealVector(n); for (int j = 0; j < n; j++) { b.set(0.); b.setEntry(j, 1.); - final RealVector x = solver.solve(a, b); - final RealVector y = a.operate(x); - for (int i = 0; i < n; i++) { - final double actual = b.getEntry(i) - y.getEntry(i); - final double expected = r.getEntry(i); - final double delta = 1E-6 * Math.abs(expected); - final String msg = String - .format("column %d, residual %d", i, j); - Assert.assertEquals(msg, expected, actual, delta); + boolean caught = false; + try { + solver.solve(a, b); + } catch (MaxCountExceededException e) { + caught = true; + final RealVector y = a.operate(x); + for (int i = 0; i < n; i++) { + final double actual = b.getEntry(i) - y.getEntry(i); + final double expected = r.getEntry(i); + final double delta = 1E-6 * Math.abs(expected); + final String msg = String + .format("column %d, residual %d", i, j); + Assert.assertEquals(msg, expected, actual, delta); + } } + Assert + .assertTrue("MaxCountExceededException should have been caught", + caught); } } @@ -331,7 +342,7 @@ public class ConjugateGradientTest { } } - @Test(expected = MaxCountExceededException.class) + @Test public void testPreconditionedResidual() { final int n = 10; final int maxIterations = n; @@ -340,10 +351,11 @@ public class ConjugateGradientTest { final ConjugateGradient solver; solver = new ConjugateGradient(maxIterations, 1E-15, true); final RealVector r = new ArrayRealVector(n); + final RealVector x = new ArrayRealVector(n); final IterationListener listener = new IterationListener() { public void terminationPerformed(final IterationEvent e) { - r.setSubVector(0, ((ProvidesResidual) e).getResidual()); + // Do nothing } public void iterationStarted(final IterationEvent e) { @@ -351,7 +363,10 @@ public class ConjugateGradientTest { } public void iterationPerformed(final IterationEvent e) { - // Do nothing + RealVector v = ((ProvidesResidual)e).getResidual(); + r.setSubVector(0, v); + v = ((IterativeLinearSolverEvent) e).getSolution(); + x.setSubVector(0, v); } public void initializationPerformed(final IterationEvent e) { @@ -364,20 +379,25 @@ public class ConjugateGradientTest { for (int j = 0; j < n; j++) { b.set(0.); b.setEntry(j, 1.); - final RealVector x = solver.solve(a, m, b); - final RealVector y = a.operate(x); - double rnorm = 0.; - for (int i = 0; i < n; i++) { - final double actual = b.getEntry(i) - y.getEntry(i); - final double expected = r.getEntry(i); - final double delta = 1E-6 * Math.abs(expected); - final String msg = String - .format("column %d, residual %d", i, j); - Assert.assertEquals(msg, expected, actual, delta); + + boolean caught = false; + try { + solver.solve(a, m, b); + } catch (MaxCountExceededException e) { + caught = true; + final RealVector y = a.operate(x); + for (int i = 0; i < n; i++) { + final double actual = b.getEntry(i) - y.getEntry(i); + final double expected = r.getEntry(i); + final double delta = 1E-6 * Math.abs(expected); + final String msg = String + .format("column %d, residual %d", i, j); + Assert.assertEquals(msg, expected, actual, delta); + } } - rnorm = r.getNorm(); - Assert.assertEquals("norm of residual", rnorm, r.getNorm(), - 1E-6 * Math.abs(rnorm)); + Assert + .assertTrue("MaxCountExceededException should have been caught", + caught); } }