MATH-1172: "SimpleCurveFitter" as parent class for curve fitter implementations.

This commit is contained in:
Gilles Sadowski 2021-05-29 00:34:28 +02:00
parent 9146f7abe2
commit 1d9670cb12
7 changed files with 262 additions and 483 deletions

View File

@ -16,22 +16,15 @@
*/
package org.apache.commons.math4.legacy.fitting;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Collection;
import org.apache.commons.math4.legacy.analysis.function.Gaussian;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
import org.apache.commons.math4.legacy.util.FastMath;
/**
@ -69,7 +62,7 @@ import org.apache.commons.math4.legacy.util.FastMath;
*
* @since 3.3
*/
public class GaussianCurveFitter extends AbstractCurveFitter {
public class GaussianCurveFitter extends SimpleCurveFitter {
/** Parametric function to be fitted. */
private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
/** {@inheritDoc} */
@ -98,10 +91,6 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
return v;
}
};
/** Initial guess. */
private final double[] initialGuess;
/** Maximum number of iterations of the optimization algorithm. */
private final int maxIter;
/**
* Constructor used by the factory methods.
@ -112,8 +101,7 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
*/
private GaussianCurveFitter(double[] initialGuess,
int maxIter) {
this.initialGuess = initialGuess;
this.maxIter = maxIter;
super(FUNCTION, initialGuess, new ParameterGuesser(), maxIter);
}
/**
@ -131,87 +119,28 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
return new GaussianCurveFitter(null, Integer.MAX_VALUE);
}
/**
* Configure the start point (initial guess).
* @param newStart new start point (initial guess)
* @return a new instance.
*/
public GaussianCurveFitter withStartPoint(double[] newStart) {
return new GaussianCurveFitter(newStart.clone(),
maxIter);
}
/**
* Configure the maximum number of iterations.
* @param newMaxIter maximum number of iterations
* @return a new instance.
*/
public GaussianCurveFitter withMaxIterations(int newMaxIter) {
return new GaussianCurveFitter(initialGuess,
newMaxIter);
}
/** {@inheritDoc} */
@Override
protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
// Prepare least-squares problem.
final int len = observations.size();
final double[] target = new double[len];
final double[] weights = new double[len];
int i = 0;
for (WeightedObservedPoint obs : observations) {
target[i] = obs.getY();
weights[i] = obs.getWeight();
++i;
}
final AbstractCurveFitter.TheoreticalValuesFunction model =
new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
final double[] startPoint = initialGuess != null ?
initialGuess :
// Compute estimation.
new ParameterGuesser(observations).guess();
// Return a new least squares problem set up to fit a Gaussian curve to the
// observed points.
return new LeastSquaresBuilder().
maxEvaluations(Integer.MAX_VALUE).
maxIterations(maxIter).
start(startPoint).
target(target).
weight(new DiagonalMatrix(weights)).
model(model.getModelFunction(), model.getModelFunctionJacobian()).
build();
}
/**
* Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
* of a {@link org.apache.commons.math4.legacy.analysis.function.Gaussian.Parametric}
* based on the specified observed points.
*/
public static class ParameterGuesser {
/** Normalization factor. */
private final double norm;
/** Mean. */
private final double mean;
/** Standard deviation. */
private final double sigma;
public static class ParameterGuesser extends SimpleCurveFitter.ParameterGuesser {
/**
* Constructs instance with the specified observed points.
* {@inheritDoc}
*
* @param observations Observed points from which to guess the
* parameters of the Gaussian.
* @return the guessed parameters, in the following order:
* <ul>
* <li>Normalization factor</li>
* <li>Mean</li>
* <li>Standard deviation</li>
* </ul>
* @throws NullArgumentException if {@code observations} is
* {@code null}.
* @throws NumberIsTooSmallException if there are less than 3
* observations.
*/
public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
@Override
public double[] guess(Collection<WeightedObservedPoint> observations) {
if (observations == null) {
throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
}
@ -220,68 +149,7 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
}
final List<WeightedObservedPoint> sorted = sortObservations(observations);
final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
norm = params[0];
mean = params[1];
sigma = params[2];
}
/**
* Gets an estimation of the parameters.
*
* @return the guessed parameters, in the following order:
* <ul>
* <li>Normalization factor</li>
* <li>Mean</li>
* <li>Standard deviation</li>
* </ul>
*/
public double[] guess() {
return new double[] { norm, mean, sigma };
}
/**
* Sort the observations.
*
* @param unsorted Input observations.
* @return the input observations, sorted.
*/
private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
/** {@inheritDoc} */
@Override
public int compare(WeightedObservedPoint p1,
WeightedObservedPoint p2) {
if (p1 == null && p2 == null) {
return 0;
}
if (p1 == null) {
return -1;
}
if (p2 == null) {
return 1;
}
int comp = Double.compare(p1.getX(), p2.getX());
if (comp != 0) {
return comp;
}
comp = Double.compare(p1.getY(), p2.getY());
if (comp != 0) {
return comp;
}
comp = Double.compare(p1.getWeight(), p2.getWeight());
if (comp != 0) {
return comp;
}
return 0;
}
};
Collections.sort(observations, cmp);
return observations;
return basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
}
/**
@ -309,119 +177,5 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
return new double[] { n, points[maxYIdx].getX(), s };
}
/**
* Finds index of point in specified points with the largest Y.
*
* @param points Points to search.
* @return the index in specified points array.
*/
private int findMaxY(WeightedObservedPoint[] points) {
int maxYIdx = 0;
for (int i = 1; i < points.length; i++) {
if (points[i].getY() > points[maxYIdx].getY()) {
maxYIdx = i;
}
}
return maxYIdx;
}
/**
* Interpolates using the specified points to determine X at the
* specified Y.
*
* @param points Points to use for interpolation.
* @param startIdx Index within points from which to start the search for
* interpolation bounds points.
* @param idxStep Index step for searching interpolation bounds points.
* @param y Y value for which X should be determined.
* @return the value of X for the specified Y.
* @throws ZeroException if {@code idxStep} is 0.
* @throws OutOfRangeException if specified {@code y} is not within the
* range of the specified {@code points}.
*/
private double interpolateXAtY(WeightedObservedPoint[] points,
int startIdx,
int idxStep,
double y)
throws OutOfRangeException {
if (idxStep == 0) {
throw new ZeroException();
}
final WeightedObservedPoint[] twoPoints
= getInterpolationPointsForY(points, startIdx, idxStep, y);
final WeightedObservedPoint p1 = twoPoints[0];
final WeightedObservedPoint p2 = twoPoints[1];
if (p1.getY() == y) {
return p1.getX();
}
if (p2.getY() == y) {
return p2.getX();
}
return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
(p2.getY() - p1.getY()));
}
/**
* Gets the two bounding interpolation points from the specified points
* suitable for determining X at the specified Y.
*
* @param points Points to use for interpolation.
* @param startIdx Index within points from which to start search for
* interpolation bounds points.
* @param idxStep Index step for search for interpolation bounds points.
* @param y Y value for which X should be determined.
* @return the array containing two points suitable for determining X at
* the specified Y.
* @throws ZeroException if {@code idxStep} is 0.
* @throws OutOfRangeException if specified {@code y} is not within the
* range of the specified {@code points}.
*/
private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
int startIdx,
int idxStep,
double y)
throws OutOfRangeException {
if (idxStep == 0) {
throw new ZeroException();
}
for (int i = startIdx;
idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
i += idxStep) {
final WeightedObservedPoint p1 = points[i];
final WeightedObservedPoint p2 = points[i + idxStep];
if (isBetween(y, p1.getY(), p2.getY())) {
if (idxStep < 0) {
return new WeightedObservedPoint[] { p2, p1 };
} else {
return new WeightedObservedPoint[] { p1, p2 };
}
}
}
// Boundaries are replaced by dummy values because the raised
// exception is caught and the message never displayed.
// TODO: Exceptions should not be used for flow control.
throw new OutOfRangeException(y,
Double.NEGATIVE_INFINITY,
Double.POSITIVE_INFINITY);
}
/**
* Determines whether a value is between two other values.
*
* @param value Value to test whether it is between {@code boundary1}
* and {@code boundary2}.
* @param boundary1 One end of the range.
* @param boundary2 Other end of the range.
* @return {@code true} if {@code value} is between {@code boundary1} and
* {@code boundary2} (inclusive), {@code false} otherwise.
*/
private boolean isBetween(double value,
double boundary1,
double boundary2) {
return (value >= boundary1 && value <= boundary2) ||
(value >= boundary2 && value <= boundary1);
}
}
}

View File

@ -25,9 +25,6 @@ import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
import org.apache.commons.math4.legacy.util.FastMath;
/**
@ -46,13 +43,9 @@ import org.apache.commons.math4.legacy.util.FastMath;
*
* @since 3.3
*/
public class HarmonicCurveFitter extends AbstractCurveFitter {
public class HarmonicCurveFitter extends SimpleCurveFitter {
/** Parametric function to be fitted. */
private static final HarmonicOscillator.Parametric FUNCTION = new HarmonicOscillator.Parametric();
/** Initial guess. */
private final double[] initialGuess;
/** Maximum number of iterations of the optimization algorithm. */
private final int maxIter;
/**
* Constructor used by the factory methods.
@ -63,8 +56,7 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
*/
private HarmonicCurveFitter(double[] initialGuess,
int maxIter) {
this.initialGuess = initialGuess;
this.maxIter = maxIter;
super(FUNCTION, initialGuess, new ParameterGuesser(), maxIter);
}
/**
@ -82,63 +74,6 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
return new HarmonicCurveFitter(null, Integer.MAX_VALUE);
}
/**
* Configure the start point (initial guess).
* @param newStart new start point (initial guess)
* @return a new instance.
*/
public HarmonicCurveFitter withStartPoint(double[] newStart) {
return new HarmonicCurveFitter(newStart.clone(),
maxIter);
}
/**
* Configure the maximum number of iterations.
* @param newMaxIter maximum number of iterations
* @return a new instance.
*/
public HarmonicCurveFitter withMaxIterations(int newMaxIter) {
return new HarmonicCurveFitter(initialGuess,
newMaxIter);
}
/** {@inheritDoc} */
@Override
protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
// Prepare least-squares problem.
final int len = observations.size();
final double[] target = new double[len];
final double[] weights = new double[len];
int i = 0;
for (WeightedObservedPoint obs : observations) {
target[i] = obs.getY();
weights[i] = obs.getWeight();
++i;
}
final AbstractCurveFitter.TheoreticalValuesFunction model
= new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION,
observations);
final double[] startPoint = initialGuess != null ?
initialGuess :
// Compute estimation.
new ParameterGuesser(observations).guess();
// Return a new optimizer set up to fit a Gaussian curve to the
// observed points.
return new LeastSquaresBuilder().
maxEvaluations(Integer.MAX_VALUE).
maxIterations(maxIter).
start(startPoint).
target(target).
weight(new DiagonalMatrix(weights)).
model(model.getModelFunction(), model.getModelFunctionJacobian()).
build();
}
/**
* This class guesses harmonic coefficients from a sample.
* <p>The algorithm used to guess the coefficients is as follows:</p>
@ -238,24 +173,22 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
* estimations, these operations run in \(O(n)\) time, where \(n\) is the
* number of measurements.</p>
*/
public static class ParameterGuesser {
/** Amplitude. */
private final double a;
/** Angular frequency. */
private final double omega;
/** Phase. */
private final double phi;
public static class ParameterGuesser extends SimpleCurveFitter.ParameterGuesser {
/**
* Simple constructor.
* {@inheritDoc}
*
* @param observations Sampled observations.
* @return the guessed parameters, in the following order:
* <ul>
* <li>Amplitude</li>
* <li>Angular frequency</li>
* <li>Phase</li>
* </ul>
* @throws NumberIsTooSmallException if the sample is too short.
* @throws ZeroException if the abscissa range is zero.
* @throws MathIllegalStateException when the guessing procedure cannot
* produce sensible results.
*/
public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
public double[] guess(Collection<WeightedObservedPoint> observations) {
if (observations.size() < 4) {
throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
observations.size(), 4, true);
@ -265,61 +198,14 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
= sortObservations(observations).toArray(new WeightedObservedPoint[0]);
final double aOmega[] = guessAOmega(sorted);
a = aOmega[0];
omega = aOmega[1];
final double a = aOmega[0];
final double omega = aOmega[1];
phi = guessPhi(sorted);
}
final double phi = guessPhi(sorted, omega);
/**
* Gets an estimation of the parameters.
*
* @return the guessed parameters, in the following order:
* <ul>
* <li>Amplitude</li>
* <li>Angular frequency</li>
* <li>Phase</li>
* </ul>
*/
public double[] guess() {
return new double[] { a, omega, phi };
}
/**
* Sort the observations with respect to the abscissa.
*
* @param unsorted Input observations.
* @return the input observations, sorted.
*/
private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
// Since the samples are almost always already sorted, this
// method is implemented as an insertion sort that reorders the
// elements in place. Insertion sort is very efficient in this case.
WeightedObservedPoint curr = observations.get(0);
final int len = observations.size();
for (int j = 1; j < len; j++) {
WeightedObservedPoint prec = curr;
curr = observations.get(j);
if (curr.getX() < prec.getX()) {
// the current element should be inserted closer to the beginning
int i = j - 1;
WeightedObservedPoint mI = observations.get(i);
while ((i >= 0) && (curr.getX() < mI.getX())) {
observations.set(i + 1, mI);
if (i-- != 0) {
mI = observations.get(i);
}
}
observations.set(i + 1, curr);
curr = observations.get(j);
}
}
return observations;
}
/**
* Estimate a first guess of the amplitude and angular frequency.
*
@ -415,9 +301,11 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
* Estimate a first guess of the phase.
*
* @param observations Observations, sorted w.r.t. abscissa.
* @param omega Angular frequency.
* @return the guessed phase.
*/
private double guessPhi(WeightedObservedPoint[] observations) {
private double guessPhi(WeightedObservedPoint[] observations,
double omega) {
// initialize the means
double fcMean = 0;
double fsMean = 0;

View File

@ -19,10 +19,6 @@ package org.apache.commons.math4.legacy.fitting;
import java.util.Collection;
import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math4.legacy.exception.MathInternalError;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
/**
* Fits points to a {@link
@ -36,25 +32,19 @@ import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
*
* @since 3.3
*/
public class PolynomialCurveFitter extends AbstractCurveFitter {
public class PolynomialCurveFitter extends SimpleCurveFitter {
/** Parametric function to be fitted. */
private static final PolynomialFunction.Parametric FUNCTION = new PolynomialFunction.Parametric();
/** Initial guess. */
private final double[] initialGuess;
/** Maximum number of iterations of the optimization algorithm. */
private final int maxIter;
/**
* Constructor used by the factory methods.
*
* @param initialGuess Initial guess.
* @param maxIter Maximum number of iterations of the optimization algorithm.
* @throws MathInternalError if {@code initialGuess} is {@code null}.
*/
private PolynomialCurveFitter(double[] initialGuess,
int maxIter) {
this.initialGuess = initialGuess;
this.maxIter = maxIter;
super(FUNCTION, initialGuess, null, maxIter);
}
/**
@ -72,60 +62,4 @@ public class PolynomialCurveFitter extends AbstractCurveFitter {
public static PolynomialCurveFitter create(int degree) {
return new PolynomialCurveFitter(new double[degree + 1], Integer.MAX_VALUE);
}
/**
* Configure the start point (initial guess).
* @param newStart new start point (initial guess)
* @return a new instance.
*/
public PolynomialCurveFitter withStartPoint(double[] newStart) {
return new PolynomialCurveFitter(newStart.clone(),
maxIter);
}
/**
* Configure the maximum number of iterations.
* @param newMaxIter maximum number of iterations
* @return a new instance.
*/
public PolynomialCurveFitter withMaxIterations(int newMaxIter) {
return new PolynomialCurveFitter(initialGuess,
newMaxIter);
}
/** {@inheritDoc} */
@Override
protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
// Prepare least-squares problem.
final int len = observations.size();
final double[] target = new double[len];
final double[] weights = new double[len];
int i = 0;
for (WeightedObservedPoint obs : observations) {
target[i] = obs.getY();
weights[i] = obs.getWeight();
++i;
}
final AbstractCurveFitter.TheoreticalValuesFunction model =
new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
if (initialGuess == null) {
throw new MathInternalError();
}
// Return a new least squares problem set up to fit a polynomial curve to the
// observed points.
return new LeastSquaresBuilder().
maxEvaluations(Integer.MAX_VALUE).
maxIterations(maxIter).
start(initialGuess).
target(target).
weight(new DiagonalMatrix(weights)).
model(model.getModelFunction(), model.getModelFunctionJacobian()).
build();
}
}

View File

@ -16,8 +16,14 @@
*/
package org.apache.commons.math4.legacy.fitting;
import java.util.Collections;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.ArrayList;
import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
@ -33,6 +39,8 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
private final ParametricUnivariateFunction function;
/** Initial guess for the parameters. */
private final double[] initialGuess;
/** Parameter guesser. */
private final ParameterGuesser guesser;
/** Maximum number of iterations of the optimization algorithm. */
private final int maxIter;
@ -42,13 +50,17 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
* @param function Function to fit.
* @param initialGuess Initial guess. Cannot be {@code null}. Its length must
* be consistent with the number of parameters of the {@code function} to fit.
* @param guesser Method for providing an initial guess (if {@code initialGuess}
* is {@code null}).
* @param maxIter Maximum number of iterations of the optimization algorithm.
*/
private SimpleCurveFitter(ParametricUnivariateFunction function,
double[] initialGuess,
int maxIter) {
protected SimpleCurveFitter(ParametricUnivariateFunction function,
double[] initialGuess,
ParameterGuesser guesser,
int maxIter) {
this.function = function;
this.initialGuess = initialGuess;
this.guesser = guesser;
this.maxIter = maxIter;
}
@ -68,7 +80,24 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
*/
public static SimpleCurveFitter create(ParametricUnivariateFunction f,
double[] start) {
return new SimpleCurveFitter(f, start, Integer.MAX_VALUE);
return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE);
}
/**
* Creates a curve fitter.
* The maximum number of iterations of the optimization algorithm is set
* to {@link Integer#MAX_VALUE}.
*
* @param f Function to fit.
* @param guesser Method for providing an initial guess.
* @return a curve fitter.
*
* @see #withStartPoint(double[])
* @see #withMaxIterations(int)
*/
public static SimpleCurveFitter create(ParametricUnivariateFunction f,
ParameterGuesser guesser) {
return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
}
/**
@ -79,6 +108,7 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
public SimpleCurveFitter withStartPoint(double[] newStart) {
return new SimpleCurveFitter(function,
newStart.clone(),
null,
maxIter);
}
@ -90,6 +120,7 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
public SimpleCurveFitter withMaxIterations(int newMaxIter) {
return new SimpleCurveFitter(function,
initialGuess,
guesser,
newMaxIter);
}
@ -112,14 +143,186 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
= new AbstractCurveFitter.TheoreticalValuesFunction(function,
observations);
final double[] startPoint = initialGuess != null ?
initialGuess :
// Compute estimation.
guesser.guess(observations);
// Create an optimizer for fitting the curve to the observed points.
return new LeastSquaresBuilder().
maxEvaluations(Integer.MAX_VALUE).
maxIterations(maxIter).
start(initialGuess).
start(startPoint).
target(target).
weight(new DiagonalMatrix(weights)).
model(model.getModelFunction(), model.getModelFunctionJacobian()).
build();
}
/**
* Guesses the parameters.
*/
public static abstract class ParameterGuesser {
private final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
/** {@inheritDoc} */
@Override
public int compare(WeightedObservedPoint p1,
WeightedObservedPoint p2) {
if (p1 == null && p2 == null) {
return 0;
}
if (p1 == null) {
return -1;
}
if (p2 == null) {
return 1;
}
int comp = Double.compare(p1.getX(), p2.getX());
if (comp != 0) {
return comp;
}
comp = Double.compare(p1.getY(), p2.getY());
if (comp != 0) {
return comp;
}
comp = Double.compare(p1.getWeight(), p2.getWeight());
if (comp != 0) {
return comp;
}
return 0;
}
};
/**
* Computes an estimation of the parameters.
*
* @param obs Observations.
* @return the guessed parameters.
*/
public abstract double[] guess(Collection<WeightedObservedPoint> obs);
/**
* Sort the observations.
*
* @param unsorted Input observations.
* @return the input observations, sorted.
*/
protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
Collections.sort(observations, CMP);
return observations;
}
/**
* Finds index of point in specified points with the largest Y.
*
* @param points Points to search.
* @return the index in specified points array.
*/
protected int findMaxY(WeightedObservedPoint[] points) {
int maxYIdx = 0;
for (int i = 1; i < points.length; i++) {
if (points[i].getY() > points[maxYIdx].getY()) {
maxYIdx = i;
}
}
return maxYIdx;
}
/**
* Interpolates using the specified points to determine X at the
* specified Y.
*
* @param points Points to use for interpolation.
* @param startIdx Index within points from which to start the search for
* interpolation bounds points.
* @param idxStep Index step for searching interpolation bounds points.
* @param y Y value for which X should be determined.
* @return the value of X for the specified Y.
* @throws ZeroException if {@code idxStep} is 0.
* @throws OutOfRangeException if specified {@code y} is not within the
* range of the specified {@code points}.
*/
protected double interpolateXAtY(WeightedObservedPoint[] points,
int startIdx,
int idxStep,
double y) {
if (idxStep == 0) {
throw new ZeroException();
}
final WeightedObservedPoint[] twoPoints
= getInterpolationPointsForY(points, startIdx, idxStep, y);
final WeightedObservedPoint p1 = twoPoints[0];
final WeightedObservedPoint p2 = twoPoints[1];
if (p1.getY() == y) {
return p1.getX();
}
if (p2.getY() == y) {
return p2.getX();
}
return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
(p2.getY() - p1.getY()));
}
/**
* Gets the two bounding interpolation points from the specified points
* suitable for determining X at the specified Y.
*
* @param points Points to use for interpolation.
* @param startIdx Index within points from which to start search for
* interpolation bounds points.
* @param idxStep Index step for search for interpolation bounds points.
* @param y Y value for which X should be determined.
* @return the array containing two points suitable for determining X at
* the specified Y.
* @throws ZeroException if {@code idxStep} is 0.
* @throws OutOfRangeException if specified {@code y} is not within the
* range of the specified {@code points}.
*/
private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
int startIdx,
int idxStep,
double y) {
if (idxStep == 0) {
throw new ZeroException();
}
for (int i = startIdx;
idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
i += idxStep) {
final WeightedObservedPoint p1 = points[i];
final WeightedObservedPoint p2 = points[i + idxStep];
if (isBetween(y, p1.getY(), p2.getY())) {
if (idxStep < 0) {
return new WeightedObservedPoint[] { p2, p1 };
} else {
return new WeightedObservedPoint[] { p1, p2 };
}
}
}
// Boundaries are replaced by dummy values because the raised
// exception is caught and the message never displayed.
// TODO: Exceptions should not be used for flow control.
throw new OutOfRangeException(y,
Double.NEGATIVE_INFINITY,
Double.POSITIVE_INFINITY);
}
/**
* Determines whether a value is between two other values.
*
* @param value Value to test whether it is between {@code boundary1}
* and {@code boundary2}.
* @param boundary1 One end of the range.
* @param boundary2 Other end of the range.
* @return {@code true} if {@code value} is between {@code boundary1} and
* {@code boundary2} (inclusive), {@code false} otherwise.
*/
private boolean isBetween(double value,
double boundary1,
double boundary2) {
return (value >= boundary1 && value <= boundary2) ||
(value >= boundary2 && value <= boundary1);
}
}
}

View File

@ -180,7 +180,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit01() {
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET1).toList());
Assert.assertEquals(3496978.1837704973, parameters[0], 1e-7);
@ -190,7 +190,7 @@ public class GaussianCurveFitterTest {
@Test
public void testDataset1LargeXShift() {
final GaussianCurveFitter fitter = GaussianCurveFitter.create();
final SimpleCurveFitter fitter = GaussianCurveFitter.create();
final double xShift = 1e8;
final double[] parameters = fitter.fit(createDataset(DATASET1, xShift, 0).toList());
@ -204,7 +204,7 @@ public class GaussianCurveFitterTest {
final int maxIter = 20;
final double[] init = { 3.5e6, 4.2, 0.1 };
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter
.withMaxIterations(maxIter)
.withStartPoint(init)
@ -220,7 +220,7 @@ public class GaussianCurveFitterTest {
final int maxIter = 1; // Too few iterations.
final double[] init = { 3.5e6, 4.2, 0.1 };
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
fitter.withMaxIterations(maxIter)
.withStartPoint(init)
.fit(createDataset(DATASET1).toList());
@ -230,7 +230,7 @@ public class GaussianCurveFitterTest {
public void testWithStartPoint() {
final double[] init = { 3.5e6, 4.2, 0.1 };
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter
.withStartPoint(init)
.fit(createDataset(DATASET1).toList());
@ -253,7 +253,7 @@ public class GaussianCurveFitterTest {
*/
@Test(expected=MathIllegalArgumentException.class)
public void testFit03() {
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
fitter.fit(createDataset(new double[][] {
{4.0254623, 531026.0},
{4.02804905, 664002.0}
@ -265,7 +265,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit04() {
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET2).toList());
Assert.assertEquals(233003.2967252038, parameters[0], 1e-4);
@ -278,7 +278,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit05() {
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET3).toList());
Assert.assertEquals(283863.81929180305, parameters[0], 1e-4);
@ -291,7 +291,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit06() {
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET4).toList());
Assert.assertEquals(285250.66754309234, parameters[0], 1e-4);
@ -304,7 +304,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit07() {
GaussianCurveFitter fitter = GaussianCurveFitter.create();
SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET5).toList());
Assert.assertEquals(3514384.729342235, parameters[0], 1e-4);

View File

@ -49,7 +49,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, f.value(x));
}
final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 1.0e-13);
Assert.assertEquals(w, fitted[1], 1.0e-13);
@ -74,7 +74,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, f.value(x) + 0.01 * randomizer.nextGaussian());
}
final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 7.6e-4);
Assert.assertEquals(w, fitted[1], 2.7e-3);
@ -90,7 +90,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, 1e-7 * randomizer.nextGaussian());
}
final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
fitter.fit(points.toList());
// This test serves to cover the part of the code of "guessAOmega"
@ -110,7 +110,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, f.value(x) + 0.01 * randomizer.nextGaussian());
}
final HarmonicCurveFitter fitter = HarmonicCurveFitter.create()
final SimpleCurveFitter fitter = HarmonicCurveFitter.create()
.withStartPoint(new double[] { 0.15, 3.6, 4.5 });
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 1.2e-3);
@ -153,7 +153,7 @@ public class HarmonicCurveFitterTest {
points.add(1, xTab[i], yTab[i]);
}
final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 7.6e-4);
Assert.assertEquals(w, fitted[1], 3.5e-3);
@ -177,6 +177,6 @@ public class HarmonicCurveFitterTest {
// and period 12, and all sample points are taken at integer abscissae
// so function values all belong to the integer subset {-3, -2, -1, 0,
// 1, 2, 3}.
new HarmonicCurveFitter.ParameterGuesser(points);
new HarmonicCurveFitter.ParameterGuesser().guess(points);
}
}

View File

@ -48,7 +48,7 @@ public class PolynomialCurveFitterTest {
}
// Start fit from initial guesses that are far from the optimal values.
final PolynomialCurveFitter fitter
final SimpleCurveFitter fitter
= PolynomialCurveFitter.create(0).withStartPoint(new double[] { -1e-20, 3e15, -5e25 });
final double[] best = fitter.fit(obs.toList());
@ -60,7 +60,7 @@ public class PolynomialCurveFitterTest {
final Random randomizer = new Random(64925784252l);
for (int degree = 1; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
for (int i = 0; i <= degree; ++i) {
@ -83,7 +83,7 @@ public class PolynomialCurveFitterTest {
double maxError = 0;
for (int degree = 0; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
for (double x = -1.0; x < 1.0; x += 0.01) {
@ -114,7 +114,7 @@ public class PolynomialCurveFitterTest {
double maxError = 0;
for (int degree = 0; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
for (int i = 0; i < 40000; ++i) {
@ -138,7 +138,7 @@ public class PolynomialCurveFitterTest {
for (int degree = 0; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
// reusing the same point over and over again does not bring
// information, the problem cannot be solved in this case for