Added method getIterations() in o.a.c.m.utils.IterationEvent (MATH-735).

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1237056 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2012-01-28 13:19:14 +00:00
parent 14f26e00eb
commit 908d8adc98
6 changed files with 169 additions and 128 deletions

View File

@ -86,21 +86,55 @@ public class ConjugateGradient
* @version $Id: ConjugateGradient.java 1175404 2011-09-25 14:48:18Z * @version $Id: ConjugateGradient.java 1175404 2011-09-25 14:48:18Z
* celestin $ * celestin $
*/ */
public abstract static class ConjugateGradientEvent public static class ConjugateGradientEvent
extends IterativeLinearSolverEvent extends IterativeLinearSolverEvent
implements ProvidesResidual { implements ProvidesResidual {
/** */ /** */
private static final long serialVersionUID = 6461730085343318121L; private static final long serialVersionUID = 20120128L;
/** The right-hand side vector. */
private final RealVector b;
/** The current estimate of the residual. */
private final RealVector r;
/** The current estimate of the solution. */
private final RealVector x;
/** /**
* Creates a new instance of this class. * Creates a new instance of this class.
* *
* @param source The iterative algorithm on which the event initially * @param source the iterative algorithm on which the event initially
* occurred. * occurred
* @param iterations the number of iterations performed at the time
* {@code this} event is created
* @param x the current estimate of the solution
* @param b the right-hand side vector
* @param r the current estimate of the residual
*/ */
public ConjugateGradientEvent(final Object source) { public ConjugateGradientEvent(final Object source, final int iterations, final RealVector x, final RealVector b, final RealVector r) {
super(source); super(source, iterations);
this.x = RealVector.unmodifiableRealVector(x);
this.b = RealVector.unmodifiableRealVector(b);
this.r = RealVector.unmodifiableRealVector(r);
}
/** {@inheritDoc} */
public RealVector getResidual() {
return r;
}
/** {@inheritDoc} */
@Override
public RealVector getRightHandSideVector() {
return b;
}
/** {@inheritDoc} */
@Override
public RealVector getSolution() {
return x;
} }
} }
@ -191,34 +225,18 @@ public class ConjugateGradient
} else { } else {
z = null; z = null;
} }
final IterativeLinearSolverEvent event; IterativeLinearSolverEvent evt;
event = new ConjugateGradientEvent(this) { evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r);
manager.fireInitializationEvent(evt);
private static final long serialVersionUID = 756911840348776676L;
public RealVector getResidual() {
return ArrayRealVector.unmodifiableRealVector(r);
}
@Override
public RealVector getRightHandSideVector() {
return ArrayRealVector.unmodifiableRealVector(b);
}
@Override
public RealVector getSolution() {
return ArrayRealVector.unmodifiableRealVector(x);
}
};
manager.fireInitializationEvent(event);
if (r2 <= r2max) { if (r2 <= r2max) {
manager.fireTerminationEvent(event); manager.fireTerminationEvent(evt);
return x; return x;
} }
double rhoPrev = 0.; double rhoPrev = 0.;
while (true) { while (true) {
manager.incrementIterationCount(); manager.incrementIterationCount();
manager.fireIterationStartedEvent(event); evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r);
manager.fireIterationStartedEvent(evt);
if (m != null) { if (m != null) {
z = m.solve(r); z = m.solve(r);
} }
@ -251,9 +269,10 @@ public class ConjugateGradient
r.combineToSelf(1., -alpha, q); r.combineToSelf(1., -alpha, q);
rhoPrev = rhoNext; rhoPrev = rhoNext;
r2 = r.dotProduct(r); r2 = r.dotProduct(r);
manager.fireIterationPerformedEvent(event); evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r);
manager.fireIterationPerformedEvent(evt);
if (r2 <= r2max) { if (r2 <= r2max) {
manager.fireTerminationEvent(event); manager.fireTerminationEvent(evt);
return x; return x;
} }
} }

View File

@ -29,16 +29,18 @@ public abstract class IterativeLinearSolverEvent
extends IterationEvent { extends IterationEvent {
/** */ /** */
private static final long serialVersionUID = 283291016904748030L; private static final long serialVersionUID = 20120128L;
/** /**
* Creates a new instance of this class. * Creates a new instance of this class.
* *
* @param source The iterative algorithm on which the event initially * @param source the iterative algorithm on which the event initially
* occurred. * occurred
* @param iterations the number of iterations performed at the time
* {@code this} event is created
*/ */
public IterativeLinearSolverEvent(final Object source) { public IterativeLinearSolverEvent(final Object source, final int iterations) {
super(source); super(source, iterations);
} }
/** /**

View File

@ -652,6 +652,56 @@ public class SymmLQ
} }
} }
/**
* The type of all events fired by this implementation of the SYMMLQ method.
*
* @version $Id$
*/
private class SymmLQEvent extends IterativeLinearSolverEvent {
/*
* TODO This class relies dangerously on references being transparently
* updated.
*/
/** */
private static final long serialVersionUID = 20120128L;
/** A reference to the state of this solver. */
private final State state;
/**
* Creates a new instance of this class.
*
* @param source the iterative algorithm on which the event initially
* occurred
* @param state the state of this solver at the time of creation
*/
public SymmLQEvent(final Object source, final State state) {
super(source, getIterationManager().getIterations());
this.state = state;
}
@Override
public int getIterations() {
return getIterationManager().getIterations();
}
/** {@inheritDoc} */
@Override
public RealVector getRightHandSideVector() {
return RealVector.unmodifiableRealVector(state.b);
}
/** {@inheritDoc} */
@Override
public RealVector getSolution() {
final int n = state.x.getDimension();
final RealVector x = new ArrayRealVector(n);
state.refine(x);
return x;
}
}
/** The cubic root of {@link #MACH_PREC}. */ /** The cubic root of {@link #MACH_PREC}. */
private static final double CBRT_MACH_PREC; private static final double CBRT_MACH_PREC;
@ -1141,7 +1191,7 @@ public class SymmLQ
manager.incrementIterationCount(); manager.incrementIterationCount();
final State state = new State(a, m, b, x, goodb, shift); final State state = new State(a, m, b, x, goodb, shift);
final IterativeLinearSolverEvent event = createEvent(state); final IterativeLinearSolverEvent event = new SymmLQEvent(this, state);
if (state.beta1 == 0.) { if (state.beta1 == 0.) {
/* If b = 0 exactly, stop with x = 0. */ /* If b = 0 exactly, stop with x = 0. */
manager.fireTerminationEvent(event); manager.fireTerminationEvent(event);
@ -1201,36 +1251,4 @@ public class SymmLQ
IllConditionedOperatorException, MaxCountExceededException { IllConditionedOperatorException, MaxCountExceededException {
return solveInPlace(a, null, b, x, false, 0.); return solveInPlace(a, null, b, x, false, 0.);
} }
/**
* Creates the event to be fired during the solution process. Unmodifiable
* views of the RHS vector, and the current estimate of the solution are
* returned by the created event.
*
* @param state Reference to the current state of this algorithm.
* @return The newly created event.
*/
private IterativeLinearSolverEvent createEvent(final State state) {
final RealVector bb = RealVector.unmodifiableRealVector(state.b);
final IterativeLinearSolverEvent event;
event = new IterativeLinearSolverEvent(this) {
private static final long serialVersionUID = 3656926699603081076L;
@Override
public RealVector getRightHandSideVector() {
return bb;
}
@Override
public RealVector getSolution() {
final int n = state.x.getDimension();
final RealVector x = new ArrayRealVector(n);
state.refine(x);
return x;
}
};
return event;
}
} }

View File

@ -26,15 +26,31 @@ import java.util.EventObject;
*/ */
public class IterationEvent extends EventObject { public class IterationEvent extends EventObject {
/** */ /** */
private static final long serialVersionUID = -1405936936084001482L; private static final long serialVersionUID = 20120128L;
/** The number of iterations performed so far. */
private final int iterations;
/** /**
* Creates a new instance of this class. * Creates a new instance of this class.
* *
* @param source The iterative algorithm on which the event initially * @param source the iterative algorithm on which the event initially
* occurred. * occurred
* @param iterations the number of iterations performed at the time
* {@code this} event is created
*/ */
public IterationEvent(final Object source) { public IterationEvent(final Object source, final int iterations) {
super(source); super(source);
this.iterations = iterations;
} }
}
/**
* Returns the number of iterations performed at the time {@code this} event
* is created.
*
* @return the number of iterations performed
*/
public int getIterations() {
return iterations;
}
}

View File

@ -16,6 +16,8 @@
*/ */
package org.apache.commons.math.linear; package org.apache.commons.math.linear;
import java.util.Arrays;
import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.exception.MaxCountExceededException; import org.apache.commons.math.exception.MaxCountExceededException;
import org.apache.commons.math.util.IterationEvent; import org.apache.commons.math.util.IterationEvent;
@ -453,26 +455,29 @@ public class ConjugateGradientTest {
final int maxIterations = 100; final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n); final RealLinearOperator a = new HilbertMatrix(n);
final IterativeLinearSolver solver; final IterativeLinearSolver solver;
final int[] count = new int[] { /*
0, 0, 0, 0 * count[0] = number of calls to initializationPerformed
}; * count[1] = number of calls to iterationStarted
* count[2] = number of calls to iterationPerformed
* count[3] = number of calls to terminationPerformed
*/
final int[] count = new int[] {0, 0, 0, 0};
final IterationListener listener = new IterationListener() { final IterationListener listener = new IterationListener() {
public void initializationPerformed(final IterationEvent e) { public void initializationPerformed(final IterationEvent e) {
count[0] = 1; ++count[0];
count[1] = 0;
count[2] = 0;
count[3] = 0;
} }
public void iterationPerformed(final IterationEvent e) { public void iterationPerformed(final IterationEvent e) {
++count[2]; ++count[2];
Assert.assertEquals("iteration performed",
count[2], e.getIterations() - 1);
} }
public void iterationStarted(IterationEvent e) { public void iterationStarted(final IterationEvent e) {
++count[1]; ++count[1];
Assert.assertEquals("iteration started",
count[1], e.getIterations() - 1);
} }
public void terminationPerformed(final IterationEvent e) { public void terminationPerformed(final IterationEvent e) {
@ -483,17 +488,12 @@ public class ConjugateGradientTest {
solver.getIterationManager().addIterationListener(listener); solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n); final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) { for (int j = 0; j < n; j++) {
Arrays.fill(count, 0);
b.set(0.); b.set(0.);
b.setEntry(j, 1.); b.setEntry(j, 1.);
solver.solve(a, b); solver.solve(a, b);
String msg = String.format("column %d (initialization)", j); String msg = String.format("column %d (initialization)", j);
Assert.assertEquals(msg, 1, count[0]); Assert.assertEquals(msg, 1, count[0]);
msg = String.format("column %d (iterations started)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[1]);
msg = String.format("column %d (iterations performed)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[2]);
msg = String.format("column %d (finalization)", j); msg = String.format("column %d (finalization)", j);
Assert.assertEquals(msg, 1, count[3]); Assert.assertEquals(msg, 1, count[3]);
} }

View File

@ -16,6 +16,8 @@
*/ */
package org.apache.commons.math.linear; package org.apache.commons.math.linear;
import java.util.Arrays;
import org.apache.commons.math.exception.DimensionMismatchException; import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.util.FastMath; import org.apache.commons.math.util.FastMath;
import org.apache.commons.math.util.IterationEvent; import org.apache.commons.math.util.IterationEvent;
@ -522,26 +524,29 @@ public class SymmLQTest {
final int maxIterations = 100; final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n); final RealLinearOperator a = new HilbertMatrix(n);
final IterativeLinearSolver solver; final IterativeLinearSolver solver;
final int[] count = new int[] { /*
0, 0, 0, 0 * count[0] = number of calls to initializationPerformed
}; * count[1] = number of calls to iterationStarted
* count[2] = number of calls to iterationPerformed
* count[3] = number of calls to terminationPerformed
*/
final int[] count = new int[] {0, 0, 0, 0};
final IterationListener listener = new IterationListener() { final IterationListener listener = new IterationListener() {
public void initializationPerformed(final IterationEvent e) { public void initializationPerformed(final IterationEvent e) {
count[0] = 1; ++count[0];
count[1] = 0;
count[2] = 0;
count[3] = 0;
} }
public void iterationPerformed(final IterationEvent e) { public void iterationPerformed(final IterationEvent e) {
++count[2]; ++count[2];
Assert.assertEquals("iteration performed",
count[2], e.getIterations() - 1);
} }
public void iterationStarted(final IterationEvent e) { public void iterationStarted(final IterationEvent e) {
++count[1]; ++count[1];
Assert.assertEquals("iteration started",
count[1], e.getIterations() - 1);
} }
public void terminationPerformed(final IterationEvent e) { public void terminationPerformed(final IterationEvent e) {
@ -552,17 +557,12 @@ public class SymmLQTest {
solver.getIterationManager().addIterationListener(listener); solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n); final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) { for (int j = 0; j < n; j++) {
Arrays.fill(count, 0);
b.set(0.); b.set(0.);
b.setEntry(j, 1.); b.setEntry(j, 1.);
solver.solve(a, b); solver.solve(a, b);
String msg = String.format("column %d (initialization)", j); String msg = String.format("column %d (initialization)", j);
Assert.assertEquals(msg, 1, count[0]); Assert.assertEquals(msg, 1, count[0]);
msg = String.format("column %d (iterations started)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[1]);
msg = String.format("column %d (iterations performed)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[2]);
msg = String.format("column %d (finalization)", j); msg = String.format("column %d (finalization)", j);
Assert.assertEquals(msg, 1, count[3]); Assert.assertEquals(msg, 1, count[3]);
} }
@ -572,41 +572,27 @@ public class SymmLQTest {
public void testNonSelfAdjointOperator() { public void testNonSelfAdjointOperator() {
final RealLinearOperator a; final RealLinearOperator a;
a = new Array2DRowRealMatrix(new double[][] { a = new Array2DRowRealMatrix(new double[][] {
{ {1., 2., 3.},
1., 2., 3. {2., 4., 5.},
}, { {2.999, 5., 6.}
2., 4., 5.
}, {
2.999, 5., 6.
}
}); });
final RealVector b; final RealVector b;
b = new ArrayRealVector(new double[] { b = new ArrayRealVector(new double[] {1., 1., 1.});
1., 1., 1.
});
new SymmLQ(100, 1., true).solve(a, b); new SymmLQ(100, 1., true).solve(a, b);
} }
@Test(expected = NonSelfAdjointOperatorException.class) @Test(expected = NonSelfAdjointOperatorException.class)
public void testNonSelfAdjointPreconditioner() { public void testNonSelfAdjointPreconditioner() {
final RealLinearOperator a = new Array2DRowRealMatrix(new double[][] { final RealLinearOperator a = new Array2DRowRealMatrix(new double[][] {
{ {1., 2., 3.},
1., 2., 3. {2., 4., 5.},
}, { {3., 5., 6.}
2., 4., 5.
}, {
3., 5., 6.
}
}); });
final Array2DRowRealMatrix mMat; final Array2DRowRealMatrix mMat;
mMat = new Array2DRowRealMatrix(new double[][] { mMat = new Array2DRowRealMatrix(new double[][] {
{ {1., 0., 1.},
1., 0., 1. {0., 1., 0.},
}, { {0., 0., 1.}
0., 1., 0.
}, {
0., 0., 1.
}
}); });
final DecompositionSolver mSolver; final DecompositionSolver mSolver;
mSolver = new LUDecomposition(mMat).getSolver(); mSolver = new LUDecomposition(mMat).getSolver();