diff --git a/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java b/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java index b83aa2960..0822a2f54 100644 --- a/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java +++ b/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java @@ -86,21 +86,55 @@ public class ConjugateGradient * @version $Id: ConjugateGradient.java 1175404 2011-09-25 14:48:18Z * celestin $ */ - public abstract static class ConjugateGradientEvent + public static class ConjugateGradientEvent extends IterativeLinearSolverEvent implements ProvidesResidual { /** */ - private static final long serialVersionUID = 6461730085343318121L; + private static final long serialVersionUID = 20120128L; + + /** The right-hand side vector. */ + private final RealVector b; + + /** The current estimate of the residual. */ + private final RealVector r; + + /** The current estimate of the solution. */ + private final RealVector x; /** * Creates a new instance of this class. * - * @param source The iterative algorithm on which the event initially - * occurred. + * @param source the iterative algorithm on which the event initially + * occurred + * @param iterations the number of iterations performed at the time + * {@code this} event is created + * @param x the current estimate of the solution + * @param b the right-hand side vector + * @param r the current estimate of the residual */ - public ConjugateGradientEvent(final Object source) { - super(source); + public ConjugateGradientEvent(final Object source, final int iterations, final RealVector x, final RealVector b, final RealVector r) { + super(source, iterations); + this.x = RealVector.unmodifiableRealVector(x); + this.b = RealVector.unmodifiableRealVector(b); + this.r = RealVector.unmodifiableRealVector(r); + } + + /** {@inheritDoc} */ + public RealVector getResidual() { + return r; + } + + /** {@inheritDoc} */ + @Override + public RealVector getRightHandSideVector() { + return b; + } + + /** {@inheritDoc} */ + @Override + public RealVector getSolution() { + return x; } } @@ -191,34 +225,18 @@ public class ConjugateGradient } else { z = null; } - final IterativeLinearSolverEvent event; - event = new ConjugateGradientEvent(this) { - - private static final long serialVersionUID = 756911840348776676L; - - public RealVector getResidual() { - return ArrayRealVector.unmodifiableRealVector(r); - } - - @Override - public RealVector getRightHandSideVector() { - return ArrayRealVector.unmodifiableRealVector(b); - } - - @Override - public RealVector getSolution() { - return ArrayRealVector.unmodifiableRealVector(x); - } - }; - manager.fireInitializationEvent(event); + IterativeLinearSolverEvent evt; + evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r); + manager.fireInitializationEvent(evt); if (r2 <= r2max) { - manager.fireTerminationEvent(event); + manager.fireTerminationEvent(evt); return x; } double rhoPrev = 0.; while (true) { manager.incrementIterationCount(); - manager.fireIterationStartedEvent(event); + evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r); + manager.fireIterationStartedEvent(evt); if (m != null) { z = m.solve(r); } @@ -251,9 +269,10 @@ public class ConjugateGradient r.combineToSelf(1., -alpha, q); rhoPrev = rhoNext; r2 = r.dotProduct(r); - manager.fireIterationPerformedEvent(event); + evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r); + manager.fireIterationPerformedEvent(evt); if (r2 <= r2max) { - manager.fireTerminationEvent(event); + manager.fireTerminationEvent(evt); return x; } } diff --git a/src/main/java/org/apache/commons/math/linear/IterativeLinearSolverEvent.java b/src/main/java/org/apache/commons/math/linear/IterativeLinearSolverEvent.java index 25b4b0d90..6b9845215 100644 --- a/src/main/java/org/apache/commons/math/linear/IterativeLinearSolverEvent.java +++ b/src/main/java/org/apache/commons/math/linear/IterativeLinearSolverEvent.java @@ -29,16 +29,18 @@ public abstract class IterativeLinearSolverEvent extends IterationEvent { /** */ - private static final long serialVersionUID = 283291016904748030L; + private static final long serialVersionUID = 20120128L; /** * Creates a new instance of this class. * - * @param source The iterative algorithm on which the event initially - * occurred. + * @param source the iterative algorithm on which the event initially + * occurred + * @param iterations the number of iterations performed at the time + * {@code this} event is created */ - public IterativeLinearSolverEvent(final Object source) { - super(source); + public IterativeLinearSolverEvent(final Object source, final int iterations) { + super(source, iterations); } /** diff --git a/src/main/java/org/apache/commons/math/linear/SymmLQ.java b/src/main/java/org/apache/commons/math/linear/SymmLQ.java index 7c4fd926a..ab2a66cd7 100644 --- a/src/main/java/org/apache/commons/math/linear/SymmLQ.java +++ b/src/main/java/org/apache/commons/math/linear/SymmLQ.java @@ -652,6 +652,56 @@ public class SymmLQ } } + /** + * The type of all events fired by this implementation of the SYMMLQ method. + * + * @version $Id$ + */ + private class SymmLQEvent extends IterativeLinearSolverEvent { + /* + * TODO This class relies dangerously on references being transparently + * updated. + */ + + /** */ + private static final long serialVersionUID = 20120128L; + + /** A reference to the state of this solver. */ + private final State state; + + /** + * Creates a new instance of this class. + * + * @param source the iterative algorithm on which the event initially + * occurred + * @param state the state of this solver at the time of creation + */ + public SymmLQEvent(final Object source, final State state) { + super(source, getIterationManager().getIterations()); + this.state = state; + } + + @Override + public int getIterations() { + return getIterationManager().getIterations(); + } + + /** {@inheritDoc} */ + @Override + public RealVector getRightHandSideVector() { + return RealVector.unmodifiableRealVector(state.b); + } + + /** {@inheritDoc} */ + @Override + public RealVector getSolution() { + final int n = state.x.getDimension(); + final RealVector x = new ArrayRealVector(n); + state.refine(x); + return x; + } + } + /** The cubic root of {@link #MACH_PREC}. */ private static final double CBRT_MACH_PREC; @@ -1141,7 +1191,7 @@ public class SymmLQ manager.incrementIterationCount(); final State state = new State(a, m, b, x, goodb, shift); - final IterativeLinearSolverEvent event = createEvent(state); + final IterativeLinearSolverEvent event = new SymmLQEvent(this, state); if (state.beta1 == 0.) { /* If b = 0 exactly, stop with x = 0. */ manager.fireTerminationEvent(event); @@ -1201,36 +1251,4 @@ public class SymmLQ IllConditionedOperatorException, MaxCountExceededException { return solveInPlace(a, null, b, x, false, 0.); } - - /** - * Creates the event to be fired during the solution process. Unmodifiable - * views of the RHS vector, and the current estimate of the solution are - * returned by the created event. - * - * @param state Reference to the current state of this algorithm. - * @return The newly created event. - */ - private IterativeLinearSolverEvent createEvent(final State state) { - final RealVector bb = RealVector.unmodifiableRealVector(state.b); - - final IterativeLinearSolverEvent event; - event = new IterativeLinearSolverEvent(this) { - - private static final long serialVersionUID = 3656926699603081076L; - - @Override - public RealVector getRightHandSideVector() { - return bb; - } - - @Override - public RealVector getSolution() { - final int n = state.x.getDimension(); - final RealVector x = new ArrayRealVector(n); - state.refine(x); - return x; - } - }; - return event; - } } diff --git a/src/main/java/org/apache/commons/math/util/IterationEvent.java b/src/main/java/org/apache/commons/math/util/IterationEvent.java index 14dafe5ba..d0bc4184b 100644 --- a/src/main/java/org/apache/commons/math/util/IterationEvent.java +++ b/src/main/java/org/apache/commons/math/util/IterationEvent.java @@ -26,15 +26,31 @@ import java.util.EventObject; */ public class IterationEvent extends EventObject { /** */ - private static final long serialVersionUID = -1405936936084001482L; + private static final long serialVersionUID = 20120128L; + + /** The number of iterations performed so far. */ + private final int iterations; /** * Creates a new instance of this class. * - * @param source The iterative algorithm on which the event initially - * occurred. + * @param source the iterative algorithm on which the event initially + * occurred + * @param iterations the number of iterations performed at the time + * {@code this} event is created */ - public IterationEvent(final Object source) { + public IterationEvent(final Object source, final int iterations) { super(source); + this.iterations = iterations; } -} + + /** + * Returns the number of iterations performed at the time {@code this} event + * is created. + * + * @return the number of iterations performed + */ + public int getIterations() { + return iterations; + } + } diff --git a/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java b/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java index 8f313a150..6e8315eb2 100644 --- a/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java +++ b/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java @@ -16,6 +16,8 @@ */ package org.apache.commons.math.linear; +import java.util.Arrays; + import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.exception.MaxCountExceededException; import org.apache.commons.math.util.IterationEvent; @@ -453,26 +455,29 @@ public class ConjugateGradientTest { final int maxIterations = 100; final RealLinearOperator a = new HilbertMatrix(n); final IterativeLinearSolver solver; - final int[] count = new int[] { - 0, 0, 0, 0 - }; + /* + * count[0] = number of calls to initializationPerformed + * count[1] = number of calls to iterationStarted + * count[2] = number of calls to iterationPerformed + * count[3] = number of calls to terminationPerformed + */ + final int[] count = new int[] {0, 0, 0, 0}; final IterationListener listener = new IterationListener() { public void initializationPerformed(final IterationEvent e) { - count[0] = 1; - count[1] = 0; - count[2] = 0; - count[3] = 0; - + ++count[0]; } public void iterationPerformed(final IterationEvent e) { ++count[2]; + Assert.assertEquals("iteration performed", + count[2], e.getIterations() - 1); } - public void iterationStarted(IterationEvent e) { + public void iterationStarted(final IterationEvent e) { ++count[1]; - + Assert.assertEquals("iteration started", + count[1], e.getIterations() - 1); } public void terminationPerformed(final IterationEvent e) { @@ -483,17 +488,12 @@ public class ConjugateGradientTest { solver.getIterationManager().addIterationListener(listener); final RealVector b = new ArrayRealVector(n); for (int j = 0; j < n; j++) { + Arrays.fill(count, 0); b.set(0.); b.setEntry(j, 1.); solver.solve(a, b); String msg = String.format("column %d (initialization)", j); Assert.assertEquals(msg, 1, count[0]); - msg = String.format("column %d (iterations started)", j); - Assert.assertEquals(msg, solver.getIterationManager() - .getIterations() - 1, count[1]); - msg = String.format("column %d (iterations performed)", j); - Assert.assertEquals(msg, solver.getIterationManager() - .getIterations() - 1, count[2]); msg = String.format("column %d (finalization)", j); Assert.assertEquals(msg, 1, count[3]); } diff --git a/src/test/java/org/apache/commons/math/linear/SymmLQTest.java b/src/test/java/org/apache/commons/math/linear/SymmLQTest.java index c9c68cbff..5d344e42f 100644 --- a/src/test/java/org/apache/commons/math/linear/SymmLQTest.java +++ b/src/test/java/org/apache/commons/math/linear/SymmLQTest.java @@ -16,6 +16,8 @@ */ package org.apache.commons.math.linear; +import java.util.Arrays; + import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.util.FastMath; import org.apache.commons.math.util.IterationEvent; @@ -522,26 +524,29 @@ public class SymmLQTest { final int maxIterations = 100; final RealLinearOperator a = new HilbertMatrix(n); final IterativeLinearSolver solver; - final int[] count = new int[] { - 0, 0, 0, 0 - }; + /* + * count[0] = number of calls to initializationPerformed + * count[1] = number of calls to iterationStarted + * count[2] = number of calls to iterationPerformed + * count[3] = number of calls to terminationPerformed + */ + final int[] count = new int[] {0, 0, 0, 0}; final IterationListener listener = new IterationListener() { public void initializationPerformed(final IterationEvent e) { - count[0] = 1; - count[1] = 0; - count[2] = 0; - count[3] = 0; - + ++count[0]; } public void iterationPerformed(final IterationEvent e) { ++count[2]; + Assert.assertEquals("iteration performed", + count[2], e.getIterations() - 1); } public void iterationStarted(final IterationEvent e) { ++count[1]; - + Assert.assertEquals("iteration started", + count[1], e.getIterations() - 1); } public void terminationPerformed(final IterationEvent e) { @@ -552,17 +557,12 @@ public class SymmLQTest { solver.getIterationManager().addIterationListener(listener); final RealVector b = new ArrayRealVector(n); for (int j = 0; j < n; j++) { + Arrays.fill(count, 0); b.set(0.); b.setEntry(j, 1.); solver.solve(a, b); String msg = String.format("column %d (initialization)", j); Assert.assertEquals(msg, 1, count[0]); - msg = String.format("column %d (iterations started)", j); - Assert.assertEquals(msg, solver.getIterationManager() - .getIterations() - 1, count[1]); - msg = String.format("column %d (iterations performed)", j); - Assert.assertEquals(msg, solver.getIterationManager() - .getIterations() - 1, count[2]); msg = String.format("column %d (finalization)", j); Assert.assertEquals(msg, 1, count[3]); } @@ -572,41 +572,27 @@ public class SymmLQTest { public void testNonSelfAdjointOperator() { final RealLinearOperator a; a = new Array2DRowRealMatrix(new double[][] { - { - 1., 2., 3. - }, { - 2., 4., 5. - }, { - 2.999, 5., 6. - } + {1., 2., 3.}, + {2., 4., 5.}, + {2.999, 5., 6.} }); final RealVector b; - b = new ArrayRealVector(new double[] { - 1., 1., 1. - }); + b = new ArrayRealVector(new double[] {1., 1., 1.}); new SymmLQ(100, 1., true).solve(a, b); } @Test(expected = NonSelfAdjointOperatorException.class) public void testNonSelfAdjointPreconditioner() { final RealLinearOperator a = new Array2DRowRealMatrix(new double[][] { - { - 1., 2., 3. - }, { - 2., 4., 5. - }, { - 3., 5., 6. - } + {1., 2., 3.}, + {2., 4., 5.}, + {3., 5., 6.} }); final Array2DRowRealMatrix mMat; mMat = new Array2DRowRealMatrix(new double[][] { - { - 1., 0., 1. - }, { - 0., 1., 0. - }, { - 0., 0., 1. - } + {1., 0., 1.}, + {0., 1., 0.}, + {0., 0., 1.} }); final DecompositionSolver mSolver; mSolver = new LUDecomposition(mMat).getSolver();