Added method getIterations() in o.a.c.m.utils.IterationEvent (MATH-735).

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1237056 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2012-01-28 13:19:14 +00:00
parent 14f26e00eb
commit 908d8adc98
6 changed files with 169 additions and 128 deletions

View File

@ -86,21 +86,55 @@ public class ConjugateGradient
* @version $Id: ConjugateGradient.java 1175404 2011-09-25 14:48:18Z
* celestin $
*/
public abstract static class ConjugateGradientEvent
public static class ConjugateGradientEvent
extends IterativeLinearSolverEvent
implements ProvidesResidual {
/** */
private static final long serialVersionUID = 6461730085343318121L;
private static final long serialVersionUID = 20120128L;
/** The right-hand side vector. */
private final RealVector b;
/** The current estimate of the residual. */
private final RealVector r;
/** The current estimate of the solution. */
private final RealVector x;
/**
* Creates a new instance of this class.
*
* @param source The iterative algorithm on which the event initially
* occurred.
* @param source the iterative algorithm on which the event initially
* occurred
* @param iterations the number of iterations performed at the time
* {@code this} event is created
* @param x the current estimate of the solution
* @param b the right-hand side vector
* @param r the current estimate of the residual
*/
public ConjugateGradientEvent(final Object source) {
super(source);
public ConjugateGradientEvent(final Object source, final int iterations, final RealVector x, final RealVector b, final RealVector r) {
super(source, iterations);
this.x = RealVector.unmodifiableRealVector(x);
this.b = RealVector.unmodifiableRealVector(b);
this.r = RealVector.unmodifiableRealVector(r);
}
/** {@inheritDoc} */
public RealVector getResidual() {
return r;
}
/** {@inheritDoc} */
@Override
public RealVector getRightHandSideVector() {
return b;
}
/** {@inheritDoc} */
@Override
public RealVector getSolution() {
return x;
}
}
@ -191,34 +225,18 @@ public class ConjugateGradient
} else {
z = null;
}
final IterativeLinearSolverEvent event;
event = new ConjugateGradientEvent(this) {
private static final long serialVersionUID = 756911840348776676L;
public RealVector getResidual() {
return ArrayRealVector.unmodifiableRealVector(r);
}
@Override
public RealVector getRightHandSideVector() {
return ArrayRealVector.unmodifiableRealVector(b);
}
@Override
public RealVector getSolution() {
return ArrayRealVector.unmodifiableRealVector(x);
}
};
manager.fireInitializationEvent(event);
IterativeLinearSolverEvent evt;
evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r);
manager.fireInitializationEvent(evt);
if (r2 <= r2max) {
manager.fireTerminationEvent(event);
manager.fireTerminationEvent(evt);
return x;
}
double rhoPrev = 0.;
while (true) {
manager.incrementIterationCount();
manager.fireIterationStartedEvent(event);
evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r);
manager.fireIterationStartedEvent(evt);
if (m != null) {
z = m.solve(r);
}
@ -251,9 +269,10 @@ public class ConjugateGradient
r.combineToSelf(1., -alpha, q);
rhoPrev = rhoNext;
r2 = r.dotProduct(r);
manager.fireIterationPerformedEvent(event);
evt = new ConjugateGradientEvent(this, manager.getIterations(), x, b, r);
manager.fireIterationPerformedEvent(evt);
if (r2 <= r2max) {
manager.fireTerminationEvent(event);
manager.fireTerminationEvent(evt);
return x;
}
}

View File

@ -29,16 +29,18 @@ public abstract class IterativeLinearSolverEvent
extends IterationEvent {
/** */
private static final long serialVersionUID = 283291016904748030L;
private static final long serialVersionUID = 20120128L;
/**
* Creates a new instance of this class.
*
* @param source The iterative algorithm on which the event initially
* occurred.
* @param source the iterative algorithm on which the event initially
* occurred
* @param iterations the number of iterations performed at the time
* {@code this} event is created
*/
public IterativeLinearSolverEvent(final Object source) {
super(source);
public IterativeLinearSolverEvent(final Object source, final int iterations) {
super(source, iterations);
}
/**

View File

@ -652,6 +652,56 @@ public class SymmLQ
}
}
/**
* The type of all events fired by this implementation of the SYMMLQ method.
*
* @version $Id$
*/
private class SymmLQEvent extends IterativeLinearSolverEvent {
/*
* TODO This class relies dangerously on references being transparently
* updated.
*/
/** */
private static final long serialVersionUID = 20120128L;
/** A reference to the state of this solver. */
private final State state;
/**
* Creates a new instance of this class.
*
* @param source the iterative algorithm on which the event initially
* occurred
* @param state the state of this solver at the time of creation
*/
public SymmLQEvent(final Object source, final State state) {
super(source, getIterationManager().getIterations());
this.state = state;
}
@Override
public int getIterations() {
return getIterationManager().getIterations();
}
/** {@inheritDoc} */
@Override
public RealVector getRightHandSideVector() {
return RealVector.unmodifiableRealVector(state.b);
}
/** {@inheritDoc} */
@Override
public RealVector getSolution() {
final int n = state.x.getDimension();
final RealVector x = new ArrayRealVector(n);
state.refine(x);
return x;
}
}
/** The cubic root of {@link #MACH_PREC}. */
private static final double CBRT_MACH_PREC;
@ -1141,7 +1191,7 @@ public class SymmLQ
manager.incrementIterationCount();
final State state = new State(a, m, b, x, goodb, shift);
final IterativeLinearSolverEvent event = createEvent(state);
final IterativeLinearSolverEvent event = new SymmLQEvent(this, state);
if (state.beta1 == 0.) {
/* If b = 0 exactly, stop with x = 0. */
manager.fireTerminationEvent(event);
@ -1201,36 +1251,4 @@ public class SymmLQ
IllConditionedOperatorException, MaxCountExceededException {
return solveInPlace(a, null, b, x, false, 0.);
}
/**
* Creates the event to be fired during the solution process. Unmodifiable
* views of the RHS vector, and the current estimate of the solution are
* returned by the created event.
*
* @param state Reference to the current state of this algorithm.
* @return The newly created event.
*/
private IterativeLinearSolverEvent createEvent(final State state) {
final RealVector bb = RealVector.unmodifiableRealVector(state.b);
final IterativeLinearSolverEvent event;
event = new IterativeLinearSolverEvent(this) {
private static final long serialVersionUID = 3656926699603081076L;
@Override
public RealVector getRightHandSideVector() {
return bb;
}
@Override
public RealVector getSolution() {
final int n = state.x.getDimension();
final RealVector x = new ArrayRealVector(n);
state.refine(x);
return x;
}
};
return event;
}
}

View File

@ -26,15 +26,31 @@ import java.util.EventObject;
*/
public class IterationEvent extends EventObject {
/** */
private static final long serialVersionUID = -1405936936084001482L;
private static final long serialVersionUID = 20120128L;
/** The number of iterations performed so far. */
private final int iterations;
/**
* Creates a new instance of this class.
*
* @param source The iterative algorithm on which the event initially
* occurred.
* @param source the iterative algorithm on which the event initially
* occurred
* @param iterations the number of iterations performed at the time
* {@code this} event is created
*/
public IterationEvent(final Object source) {
public IterationEvent(final Object source, final int iterations) {
super(source);
this.iterations = iterations;
}
}
/**
* Returns the number of iterations performed at the time {@code this} event
* is created.
*
* @return the number of iterations performed
*/
public int getIterations() {
return iterations;
}
}

View File

@ -16,6 +16,8 @@
*/
package org.apache.commons.math.linear;
import java.util.Arrays;
import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.exception.MaxCountExceededException;
import org.apache.commons.math.util.IterationEvent;
@ -453,26 +455,29 @@ public class ConjugateGradientTest {
final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n);
final IterativeLinearSolver solver;
final int[] count = new int[] {
0, 0, 0, 0
};
/*
* count[0] = number of calls to initializationPerformed
* count[1] = number of calls to iterationStarted
* count[2] = number of calls to iterationPerformed
* count[3] = number of calls to terminationPerformed
*/
final int[] count = new int[] {0, 0, 0, 0};
final IterationListener listener = new IterationListener() {
public void initializationPerformed(final IterationEvent e) {
count[0] = 1;
count[1] = 0;
count[2] = 0;
count[3] = 0;
++count[0];
}
public void iterationPerformed(final IterationEvent e) {
++count[2];
Assert.assertEquals("iteration performed",
count[2], e.getIterations() - 1);
}
public void iterationStarted(IterationEvent e) {
public void iterationStarted(final IterationEvent e) {
++count[1];
Assert.assertEquals("iteration started",
count[1], e.getIterations() - 1);
}
public void terminationPerformed(final IterationEvent e) {
@ -483,17 +488,12 @@ public class ConjugateGradientTest {
solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
Arrays.fill(count, 0);
b.set(0.);
b.setEntry(j, 1.);
solver.solve(a, b);
String msg = String.format("column %d (initialization)", j);
Assert.assertEquals(msg, 1, count[0]);
msg = String.format("column %d (iterations started)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[1]);
msg = String.format("column %d (iterations performed)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[2]);
msg = String.format("column %d (finalization)", j);
Assert.assertEquals(msg, 1, count[3]);
}

View File

@ -16,6 +16,8 @@
*/
package org.apache.commons.math.linear;
import java.util.Arrays;
import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.util.FastMath;
import org.apache.commons.math.util.IterationEvent;
@ -522,26 +524,29 @@ public class SymmLQTest {
final int maxIterations = 100;
final RealLinearOperator a = new HilbertMatrix(n);
final IterativeLinearSolver solver;
final int[] count = new int[] {
0, 0, 0, 0
};
/*
* count[0] = number of calls to initializationPerformed
* count[1] = number of calls to iterationStarted
* count[2] = number of calls to iterationPerformed
* count[3] = number of calls to terminationPerformed
*/
final int[] count = new int[] {0, 0, 0, 0};
final IterationListener listener = new IterationListener() {
public void initializationPerformed(final IterationEvent e) {
count[0] = 1;
count[1] = 0;
count[2] = 0;
count[3] = 0;
++count[0];
}
public void iterationPerformed(final IterationEvent e) {
++count[2];
Assert.assertEquals("iteration performed",
count[2], e.getIterations() - 1);
}
public void iterationStarted(final IterationEvent e) {
++count[1];
Assert.assertEquals("iteration started",
count[1], e.getIterations() - 1);
}
public void terminationPerformed(final IterationEvent e) {
@ -552,17 +557,12 @@ public class SymmLQTest {
solver.getIterationManager().addIterationListener(listener);
final RealVector b = new ArrayRealVector(n);
for (int j = 0; j < n; j++) {
Arrays.fill(count, 0);
b.set(0.);
b.setEntry(j, 1.);
solver.solve(a, b);
String msg = String.format("column %d (initialization)", j);
Assert.assertEquals(msg, 1, count[0]);
msg = String.format("column %d (iterations started)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[1]);
msg = String.format("column %d (iterations performed)", j);
Assert.assertEquals(msg, solver.getIterationManager()
.getIterations() - 1, count[2]);
msg = String.format("column %d (finalization)", j);
Assert.assertEquals(msg, 1, count[3]);
}
@ -572,41 +572,27 @@ public class SymmLQTest {
public void testNonSelfAdjointOperator() {
final RealLinearOperator a;
a = new Array2DRowRealMatrix(new double[][] {
{
1., 2., 3.
}, {
2., 4., 5.
}, {
2.999, 5., 6.
}
{1., 2., 3.},
{2., 4., 5.},
{2.999, 5., 6.}
});
final RealVector b;
b = new ArrayRealVector(new double[] {
1., 1., 1.
});
b = new ArrayRealVector(new double[] {1., 1., 1.});
new SymmLQ(100, 1., true).solve(a, b);
}
@Test(expected = NonSelfAdjointOperatorException.class)
public void testNonSelfAdjointPreconditioner() {
final RealLinearOperator a = new Array2DRowRealMatrix(new double[][] {
{
1., 2., 3.
}, {
2., 4., 5.
}, {
3., 5., 6.
}
{1., 2., 3.},
{2., 4., 5.},
{3., 5., 6.}
});
final Array2DRowRealMatrix mMat;
mMat = new Array2DRowRealMatrix(new double[][] {
{
1., 0., 1.
}, {
0., 1., 0.
}, {
0., 0., 1.
}
{1., 0., 1.},
{0., 1., 0.},
{0., 0., 1.}
});
final DecompositionSolver mSolver;
mSolver = new LUDecomposition(mMat).getSolver();