Modifications to the ConjugateGradient class and unit tests
- altered the way iterations are counted: Incrementor is incremented prior to any modification to the current state, so that the solver is in a consistent state (accessible residual corresponds to the last estimate of the solution), even in case of MaxCountExceededException occuring. - modified some tests which were not testing anything. git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1179488 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
8fcbe82ab6
commit
61018c7997
|
@ -174,13 +174,15 @@ public class ConjugateGradient
|
||||||
manager.resetIterationCount();
|
manager.resetIterationCount();
|
||||||
final double r2max = delta * delta * b.dotProduct(b);
|
final double r2max = delta * delta * b.dotProduct(b);
|
||||||
|
|
||||||
|
// Initialization phase counts as one iteration.
|
||||||
|
manager.incrementIterationCount();
|
||||||
// p and x are constructed as copies of x0, since presumably, the type
|
// p and x are constructed as copies of x0, since presumably, the type
|
||||||
// of x is optimized for the calculation of the matrix-vector product
|
// of x is optimized for the calculation of the matrix-vector product
|
||||||
// A.x.
|
// A.x.
|
||||||
final RealVector x = x0;
|
final RealVector x = x0;
|
||||||
final RealVector p = x.copy();
|
final RealVector p = x.copy();
|
||||||
RealVector q = a.operate(p);
|
RealVector q = a.operate(p);
|
||||||
manager.incrementIterationCount();
|
|
||||||
final RealVector r = b.combine(1, -1, q);
|
final RealVector r = b.combine(1, -1, q);
|
||||||
double r2 = r.dotProduct(r);
|
double r2 = r.dotProduct(r);
|
||||||
RealVector z;
|
RealVector z;
|
||||||
|
@ -213,6 +215,7 @@ public class ConjugateGradient
|
||||||
}
|
}
|
||||||
double rhoPrev = 0.;
|
double rhoPrev = 0.;
|
||||||
while (true) {
|
while (true) {
|
||||||
|
manager.incrementIterationCount();
|
||||||
manager.fireIterationStartedEvent(event);
|
manager.fireIterationStartedEvent(event);
|
||||||
if (m != null) {
|
if (m != null) {
|
||||||
z = m.solve(r);
|
z = m.solve(r);
|
||||||
|
@ -226,13 +229,12 @@ public class ConjugateGradient
|
||||||
context.setValue(VECTOR, r);
|
context.setValue(VECTOR, r);
|
||||||
throw e;
|
throw e;
|
||||||
}
|
}
|
||||||
if (manager.getIterations() == 1) {
|
if (manager.getIterations() == 2) {
|
||||||
p.setSubVector(0, z);
|
p.setSubVector(0, z);
|
||||||
} else {
|
} else {
|
||||||
p.combineToSelf(rhoNext / rhoPrev, 1., z);
|
p.combineToSelf(rhoNext / rhoPrev, 1., z);
|
||||||
}
|
}
|
||||||
q = a.operate(p);
|
q = a.operate(p);
|
||||||
manager.incrementIterationCount();
|
|
||||||
final double pq = p.dotProduct(q);
|
final double pq = p.dotProduct(q);
|
||||||
if (check && (pq <= 0.)) {
|
if (check && (pq <= 0.)) {
|
||||||
final NonPositiveDefiniteOperatorException e;
|
final NonPositiveDefiniteOperatorException e;
|
||||||
|
|
|
@ -153,7 +153,7 @@ public class ConjugateGradientTest {
|
||||||
* due to the loss of orthogonality of the successive search directions.
|
* due to the loss of orthogonality of the successive search directions.
|
||||||
* Therefore, in the present test, the number of iterations is limited.
|
* Therefore, in the present test, the number of iterations is limited.
|
||||||
*/
|
*/
|
||||||
@Test(expected = MaxCountExceededException.class)
|
@Test
|
||||||
public void testUnpreconditionedResidual() {
|
public void testUnpreconditionedResidual() {
|
||||||
final int n = 10;
|
final int n = 10;
|
||||||
final int maxIterations = n;
|
final int maxIterations = n;
|
||||||
|
@ -161,10 +161,11 @@ public class ConjugateGradientTest {
|
||||||
final ConjugateGradient solver;
|
final ConjugateGradient solver;
|
||||||
solver = new ConjugateGradient(maxIterations, 1E-15, true);
|
solver = new ConjugateGradient(maxIterations, 1E-15, true);
|
||||||
final RealVector r = new ArrayRealVector(n);
|
final RealVector r = new ArrayRealVector(n);
|
||||||
|
final RealVector x = new ArrayRealVector(n);
|
||||||
final IterationListener listener = new IterationListener() {
|
final IterationListener listener = new IterationListener() {
|
||||||
|
|
||||||
public void terminationPerformed(final IterationEvent e) {
|
public void terminationPerformed(final IterationEvent e) {
|
||||||
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
|
// Do nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
public void iterationStarted(final IterationEvent e) {
|
public void iterationStarted(final IterationEvent e) {
|
||||||
|
@ -172,7 +173,10 @@ public class ConjugateGradientTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void iterationPerformed(final IterationEvent e) {
|
public void iterationPerformed(final IterationEvent e) {
|
||||||
// Do nothing
|
RealVector v = ((ProvidesResidual)e).getResidual();
|
||||||
|
r.setSubVector(0, v);
|
||||||
|
v = ((IterativeLinearSolverEvent) e).getSolution();
|
||||||
|
x.setSubVector(0, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void initializationPerformed(final IterationEvent e) {
|
public void initializationPerformed(final IterationEvent e) {
|
||||||
|
@ -180,22 +184,29 @@ public class ConjugateGradientTest {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
solver.getIterationManager().addIterationListener(listener);
|
solver.getIterationManager().addIterationListener(listener);
|
||||||
|
|
||||||
final RealVector b = new ArrayRealVector(n);
|
final RealVector b = new ArrayRealVector(n);
|
||||||
for (int j = 0; j < n; j++) {
|
for (int j = 0; j < n; j++) {
|
||||||
b.set(0.);
|
b.set(0.);
|
||||||
b.setEntry(j, 1.);
|
b.setEntry(j, 1.);
|
||||||
|
|
||||||
final RealVector x = solver.solve(a, b);
|
boolean caught = false;
|
||||||
final RealVector y = a.operate(x);
|
try {
|
||||||
for (int i = 0; i < n; i++) {
|
solver.solve(a, b);
|
||||||
final double actual = b.getEntry(i) - y.getEntry(i);
|
} catch (MaxCountExceededException e) {
|
||||||
final double expected = r.getEntry(i);
|
caught = true;
|
||||||
final double delta = 1E-6 * Math.abs(expected);
|
final RealVector y = a.operate(x);
|
||||||
final String msg = String
|
for (int i = 0; i < n; i++) {
|
||||||
.format("column %d, residual %d", i, j);
|
final double actual = b.getEntry(i) - y.getEntry(i);
|
||||||
Assert.assertEquals(msg, expected, actual, delta);
|
final double expected = r.getEntry(i);
|
||||||
|
final double delta = 1E-6 * Math.abs(expected);
|
||||||
|
final String msg = String
|
||||||
|
.format("column %d, residual %d", i, j);
|
||||||
|
Assert.assertEquals(msg, expected, actual, delta);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Assert
|
||||||
|
.assertTrue("MaxCountExceededException should have been caught",
|
||||||
|
caught);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -331,7 +342,7 @@ public class ConjugateGradientTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = MaxCountExceededException.class)
|
@Test
|
||||||
public void testPreconditionedResidual() {
|
public void testPreconditionedResidual() {
|
||||||
final int n = 10;
|
final int n = 10;
|
||||||
final int maxIterations = n;
|
final int maxIterations = n;
|
||||||
|
@ -340,10 +351,11 @@ public class ConjugateGradientTest {
|
||||||
final ConjugateGradient solver;
|
final ConjugateGradient solver;
|
||||||
solver = new ConjugateGradient(maxIterations, 1E-15, true);
|
solver = new ConjugateGradient(maxIterations, 1E-15, true);
|
||||||
final RealVector r = new ArrayRealVector(n);
|
final RealVector r = new ArrayRealVector(n);
|
||||||
|
final RealVector x = new ArrayRealVector(n);
|
||||||
final IterationListener listener = new IterationListener() {
|
final IterationListener listener = new IterationListener() {
|
||||||
|
|
||||||
public void terminationPerformed(final IterationEvent e) {
|
public void terminationPerformed(final IterationEvent e) {
|
||||||
r.setSubVector(0, ((ProvidesResidual) e).getResidual());
|
// Do nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
public void iterationStarted(final IterationEvent e) {
|
public void iterationStarted(final IterationEvent e) {
|
||||||
|
@ -351,7 +363,10 @@ public class ConjugateGradientTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void iterationPerformed(final IterationEvent e) {
|
public void iterationPerformed(final IterationEvent e) {
|
||||||
// Do nothing
|
RealVector v = ((ProvidesResidual)e).getResidual();
|
||||||
|
r.setSubVector(0, v);
|
||||||
|
v = ((IterativeLinearSolverEvent) e).getSolution();
|
||||||
|
x.setSubVector(0, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void initializationPerformed(final IterationEvent e) {
|
public void initializationPerformed(final IterationEvent e) {
|
||||||
|
@ -364,20 +379,25 @@ public class ConjugateGradientTest {
|
||||||
for (int j = 0; j < n; j++) {
|
for (int j = 0; j < n; j++) {
|
||||||
b.set(0.);
|
b.set(0.);
|
||||||
b.setEntry(j, 1.);
|
b.setEntry(j, 1.);
|
||||||
final RealVector x = solver.solve(a, m, b);
|
|
||||||
final RealVector y = a.operate(x);
|
boolean caught = false;
|
||||||
double rnorm = 0.;
|
try {
|
||||||
for (int i = 0; i < n; i++) {
|
solver.solve(a, m, b);
|
||||||
final double actual = b.getEntry(i) - y.getEntry(i);
|
} catch (MaxCountExceededException e) {
|
||||||
final double expected = r.getEntry(i);
|
caught = true;
|
||||||
final double delta = 1E-6 * Math.abs(expected);
|
final RealVector y = a.operate(x);
|
||||||
final String msg = String
|
for (int i = 0; i < n; i++) {
|
||||||
.format("column %d, residual %d", i, j);
|
final double actual = b.getEntry(i) - y.getEntry(i);
|
||||||
Assert.assertEquals(msg, expected, actual, delta);
|
final double expected = r.getEntry(i);
|
||||||
|
final double delta = 1E-6 * Math.abs(expected);
|
||||||
|
final String msg = String
|
||||||
|
.format("column %d, residual %d", i, j);
|
||||||
|
Assert.assertEquals(msg, expected, actual, delta);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
rnorm = r.getNorm();
|
Assert
|
||||||
Assert.assertEquals("norm of residual", rnorm, r.getNorm(),
|
.assertTrue("MaxCountExceededException should have been caught",
|
||||||
1E-6 * Math.abs(rnorm));
|
caught);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue