In o.a.c.m3.SymmLQ.State, created accessors

- RealVector getRightHandSideVector(),
  - RealVector getSolution(),
  - double getNormOfResidual(),
see MATH-761.

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/math/trunk@1303674 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Sebastien Brisard 2012-03-22 07:15:24 +00:00
parent 6f72f0544f
commit 3160b5c5cc
2 changed files with 73 additions and 38 deletions

View File

@ -307,6 +307,12 @@ public class SymmLQ
/** The value of beta[k+1] * M * P' * v[k+1]. */ /** The value of beta[k+1] * M * P' * v[k+1]. */
private RealVector r2; private RealVector r2;
/**
* The value of the updated, preconditioned residual P * r. This value is
* given by {@code min(}{@link #cgnorm}{@code , }{@link #lqnorm}{@code )}.
*/
private double rnorm;
/** Copy of the {@code shift} parameter. */ /** Copy of the {@code shift} parameter. */
private final double shift; private final double shift;
@ -331,7 +337,7 @@ public class SymmLQ
* the value of xL[k-1] if {@code goodb} is {@code false}, (xL[k-1] - * the value of xL[k-1] if {@code goodb} is {@code false}, (xL[k-1] -
* bstep[k-1] * v[1]) otherwise. * bstep[k-1] * v[1]) otherwise.
*/ */
private final RealVector x; private final RealVector xL;
/** The value of beta[k+1] * P' * v[k+1]. */ /** The value of beta[k+1] * P' * v[k+1]. */
private RealVector y; private RealVector y;
@ -375,7 +381,7 @@ public class SymmLQ
this.a = a; this.a = a;
this.minv = minv; this.minv = minv;
this.b = b; this.b = b;
this.x = x; this.xL = x;
this.goodb = goodb; this.goodb = goodb;
this.shift = shift; this.shift = shift;
this.minvb = minv == null ? b : minv.operate(b); this.minvb = minv == null ? b : minv.operate(b);
@ -477,19 +483,19 @@ public class SymmLQ
* the convergence tests involve only cgnorm, so we're unlikely to stop * the convergence tests involve only cgnorm, so we're unlikely to stop
* at an LQ point, except if the iteration limit interferes. * at an LQ point, except if the iteration limit interferes.
* *
* @param xRefined the vector to be updated with the refined value of x * @param xC the vector to be updated with the refined value of xL
*/ */
public void refine(final RealVector xRefined) { void moveToCG(final RealVector xC) {
final int n = this.x.getDimension(); final int n = this.xL.getDimension();
if (lqnorm < cgnorm) { if (lqnorm < cgnorm) {
if (!goodb) { if (!goodb) {
xRefined.setSubVector(0, this.x); xC.setSubVector(0, this.xL);
} else { } else {
final double step = bstep / beta1; final double step = bstep / beta1;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
final double bi = minvb.getEntry(i); final double bi = minvb.getEntry(i);
final double xi = this.x.getEntry(i); final double xi = this.xL.getEntry(i);
xRefined.setEntry(i, xi + step * bi); xC.setEntry(i, xi + step * bi);
} }
} }
} else { } else {
@ -500,16 +506,16 @@ public class SymmLQ
// ynorm = FastMath.sqrt(ynorm2 + zbar * zbar); // ynorm = FastMath.sqrt(ynorm2 + zbar * zbar);
if (!goodb) { if (!goodb) {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
final double xi = this.x.getEntry(i); final double xi = this.xL.getEntry(i);
final double wi = wbar.getEntry(i); final double wi = wbar.getEntry(i);
xRefined.setEntry(i, xi + zbar * wi); xC.setEntry(i, xi + zbar * wi);
} }
} else { } else {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
final double xi = this.x.getEntry(i); final double xi = this.xL.getEntry(i);
final double wi = wbar.getEntry(i); final double wi = wbar.getEntry(i);
final double bi = minvb.getEntry(i); final double bi = minvb.getEntry(i);
xRefined.setEntry(i, xi + zbar * wi + step * bi); xC.setEntry(i, xi + zbar * wi + step * bi);
} }
} }
} }
@ -521,7 +527,7 @@ public class SymmLQ
* 1. * 1.
*/ */
private void init() { private void init() {
this.x.set(0.); this.xL.set(0.);
/* /*
* Set up y for the first Lanczos vector. y and beta1 will be zero * Set up y for the first Lanczos vector. y and beta1 will be zero
* if b = 0. * if b = 0.
@ -696,12 +702,12 @@ public class SymmLQ
*/ */
final double zetaC = zeta * c; final double zetaC = zeta * c;
final double zetaS = zeta * s; final double zetaS = zeta * s;
final int n = x.getDimension(); final int n = xL.getDimension();
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
final double xi = x.getEntry(i); final double xi = xL.getEntry(i);
final double vi = v.getEntry(i); final double vi = v.getEntry(i);
final double wi = wbar.getEntry(i); final double wi = wbar.getEntry(i);
x.setEntry(i, xi + wi * zetaC + vi * zetaS); xL.setEntry(i, xi + wi * zetaC + vi * zetaS);
wbar.setEntry(i, wi * s - vi * c); wbar.setEntry(i, wi * s - vi * c);
} }
/* /*
@ -770,6 +776,7 @@ public class SymmLQ
*/ */
throw new SingularOperatorException(); throw new SingularOperatorException();
} }
rnorm = FastMath.min(cgnorm, lqnorm);
hasConverged = (cgnorm <= epsx) || (cgnorm <= epsr); hasConverged = (cgnorm <= epsx) || (cgnorm <= epsr);
} }
@ -778,7 +785,7 @@ public class SymmLQ
* *
* @return {@code true} if convergence of the iterations has occured * @return {@code true} if convergence of the iterations has occured
*/ */
public boolean hasConverged() { boolean hasConverged() {
return hasConverged; return hasConverged;
} }
@ -787,7 +794,7 @@ public class SymmLQ
* *
* @return the boolean value of {@code b == 0} * @return the boolean value of {@code b == 0}
*/ */
public boolean bEqualsNullVector() { boolean bEqualsNullVector() {
return bIsNull; return bIsNull;
} }
@ -797,9 +804,36 @@ public class SymmLQ
* *
* @return {@code true} if {@code beta < }{@link #MACH_PREC} * @return {@code true} if {@code beta < }{@link #MACH_PREC}
*/ */
public boolean betaEqualsZero() { boolean betaEqualsZero() {
return beta < MACH_PREC; return beta < MACH_PREC;
} }
/**
* Returns the right-hand side vector.
*
* @return the right-hand side vector, b
*/
RealVector getRightHandSideVector() {
return b;
}
/**
* Returns the current estimate of the solution (LQ point).
*
* @return the solution, xL
*/
RealVector getSolution() {
return xL;
}
/**
* Returns the norm of the updated, preconditioned residual.
*
* @return the norm of the residual, ||P * r||
*/
double getNormOfResidual() {
return rnorm;
}
} }
/** /**
@ -835,21 +869,20 @@ public class SymmLQ
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public double getNormOfResidual() { public double getNormOfResidual() {
return FastMath.min(state.cgnorm, state.lqnorm); return state.getNormOfResidual();
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public RealVector getRightHandSideVector() { public RealVector getRightHandSideVector() {
return RealVector.unmodifiableRealVector(state.b); return RealVector.unmodifiableRealVector(state.getRightHandSideVector());
} }
/** {@inheritDoc} */ /** {@inheritDoc} */
@Override @Override
public RealVector getSolution() { public RealVector getSolution() {
final int n = state.x.getDimension(); final RealVector x = state.getSolution().copy();
final RealVector x = new ArrayRealVector(n); state.moveToCG(x);
state.refine(x);
return x; return x;
} }
} }
@ -1180,7 +1213,11 @@ public class SymmLQ
manager.resetIterationCount(); manager.resetIterationCount();
manager.incrementIterationCount(); manager.incrementIterationCount();
final State state = new State(a, minv, b, x, goodb, shift, delta, check); final State state = new State(a, minv, b, x.copy(), goodb, shift, delta, check);
/*
* There is no need to create a new SymmLQEvent each time the state is
* updated, as SymmLQEvent keeps a reference to the current state.
*/
final IterativeLinearSolverEvent event = new SymmLQEvent(this, state); final IterativeLinearSolverEvent event = new SymmLQEvent(this, state);
if (state.bEqualsNullVector()) { if (state.bEqualsNullVector()) {
/* If b = 0 exactly, stop with x = 0. */ /* If b = 0 exactly, stop with x = 0. */
@ -1199,14 +1236,7 @@ public class SymmLQ
manager.fireIterationPerformedEvent(event); manager.fireIterationPerformedEvent(event);
} while (!state.hasConverged()); } while (!state.hasConverged());
} }
state.refine(x); state.moveToCG(x);
/*
* The following two lines are a hack because state.x is now refined,
* so further calls to state.refine() (via event.getSolution()) should
* *not* return an altered value of state.x.
*/
state.bstep = 0.;
state.gammaZeta = 0.;
manager.fireTerminationEvent(event); manager.fireTerminationEvent(event);
return x; return x;
} }

View File

@ -19,6 +19,7 @@ package org.apache.commons.math3.linear;
import java.util.Arrays; import java.util.Arrays;
import org.apache.commons.math3.exception.DimensionMismatchException; import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.IterationEvent; import org.apache.commons.math3.util.IterationEvent;
import org.apache.commons.math3.util.IterationListener; import org.apache.commons.math3.util.IterationListener;
@ -496,18 +497,21 @@ public class SymmLQTest {
public void iterationPerformed(final IterationEvent e) { public void iterationPerformed(final IterationEvent e) {
++count[2]; ++count[2];
Assert.assertEquals("iteration performed", Assert.assertEquals("iteration performed",
count[2], e.getIterations() - 1); count[2],
e.getIterations() - 1);
} }
public void iterationStarted(final IterationEvent e) { public void iterationStarted(final IterationEvent e) {
++count[1]; ++count[1];
Assert.assertEquals("iteration started", Assert.assertEquals("iteration started",
count[1], e.getIterations() - 1); count[1],
e.getIterations() - 1);
} }
public void terminationPerformed(final IterationEvent e) { public void terminationPerformed(final IterationEvent e) {
++count[3]; ++count[3];
final IterativeLinearSolverEvent ilse = (IterativeLinearSolverEvent) e; final IterativeLinearSolverEvent ilse;
ilse = (IterativeLinearSolverEvent) e;
xFromListener.setSubVector(0, ilse.getSolution()); xFromListener.setSubVector(0, ilse.getSolution());
} }
}; };
@ -524,8 +528,9 @@ public class SymmLQTest {
msg = String.format("column %d (finalization)", j); msg = String.format("column %d (finalization)", j);
Assert.assertEquals(msg, 1, count[3]); Assert.assertEquals(msg, 1, count[3]);
/* /*
* Check that solution is not "over-refined". When the last iteration has * Check that solution is not "over-refined". When the last
* occurred, no further refinement should be performed. * iteration has occurred, no further refinement should be
* performed.
*/ */
for (int i = 0; i < n; i++){ for (int i = 0; i < n; i++){
msg = String.format("row %d, column %d", i, j); msg = String.format("row %d, column %d", i, j);