MATH-1596: Remove dependency on "RandomVectorGenerator".
This commit is contained in:
parent
c93520a02f
commit
f9a7c568be
|
@ -16,10 +16,10 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.commons.math4.legacy.optim;
|
package org.apache.commons.math4.legacy.optim;
|
||||||
|
|
||||||
|
import java.util.function.Supplier;
|
||||||
import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
|
import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
|
||||||
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
|
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
|
||||||
import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
|
import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
|
||||||
import org.apache.commons.math4.legacy.random.RandomVectorGenerator;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class multi-start optimizer for a multivariate function.
|
* Base class multi-start optimizer for a multivariate function.
|
||||||
|
@ -41,9 +41,9 @@ public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
|
||||||
/** Number of evaluations already performed for all starts. */
|
/** Number of evaluations already performed for all starts. */
|
||||||
private int totalEvaluations;
|
private int totalEvaluations;
|
||||||
/** Number of starts to go. */
|
/** Number of starts to go. */
|
||||||
private int starts;
|
private final int starts;
|
||||||
/** Random generator for multi-start. */
|
/** Generator of start points ("multi-start"). */
|
||||||
private RandomVectorGenerator generator;
|
private final Supplier<double[]> generator;
|
||||||
/** Optimization data. */
|
/** Optimization data. */
|
||||||
private OptimizationData[] optimData;
|
private OptimizationData[] optimData;
|
||||||
/**
|
/**
|
||||||
|
@ -72,12 +72,12 @@ public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
|
||||||
* @param starts Number of starts to perform. If {@code starts == 1},
|
* @param starts Number of starts to perform. If {@code starts == 1},
|
||||||
* the {@link #optimize(OptimizationData[]) optimize} will return the
|
* the {@link #optimize(OptimizationData[]) optimize} will return the
|
||||||
* same solution as the given {@code optimizer} would return.
|
* same solution as the given {@code optimizer} would return.
|
||||||
* @param generator Random vector generator to use for restarts.
|
* @param generator Generator to use for restarts.
|
||||||
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
||||||
*/
|
*/
|
||||||
public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer,
|
public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer,
|
||||||
final int starts,
|
final int starts,
|
||||||
final RandomVectorGenerator generator) {
|
final Supplier<double[]> generator) {
|
||||||
super(optimizer.getConvergenceChecker());
|
super(optimizer.getConvergenceChecker());
|
||||||
|
|
||||||
if (starts < 1) {
|
if (starts < 1) {
|
||||||
|
@ -185,7 +185,7 @@ public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
|
||||||
if (attempts++ >= getMaxEvaluations()) {
|
if (attempts++ >= getMaxEvaluations()) {
|
||||||
throw new TooManyEvaluationsException(getMaxEvaluations());
|
throw new TooManyEvaluationsException(getMaxEvaluations());
|
||||||
}
|
}
|
||||||
s = generator.nextVector();
|
s = generator.get();
|
||||||
for (int k = 0; s != null && k < s.length; ++k) {
|
for (int k = 0; s != null && k < s.length; ++k) {
|
||||||
if ((min != null && s[k] < min[k]) || (max != null && s[k] > max[k])) {
|
if ((min != null && s[k] < min[k]) || (max != null && s[k] > max[k])) {
|
||||||
// reject the vector
|
// reject the vector
|
||||||
|
|
|
@ -20,12 +20,12 @@ import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
|
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
|
||||||
import org.apache.commons.math4.legacy.exception.NullArgumentException;
|
import org.apache.commons.math4.legacy.exception.NullArgumentException;
|
||||||
import org.apache.commons.math4.legacy.optim.BaseMultiStartMultivariateOptimizer;
|
import org.apache.commons.math4.legacy.optim.BaseMultiStartMultivariateOptimizer;
|
||||||
import org.apache.commons.math4.legacy.optim.PointValuePair;
|
import org.apache.commons.math4.legacy.optim.PointValuePair;
|
||||||
import org.apache.commons.math4.legacy.random.RandomVectorGenerator;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Multi-start optimizer.
|
* Multi-start optimizer.
|
||||||
|
@ -50,16 +50,14 @@ public class MultiStartMultivariateOptimizer
|
||||||
* @param starts Number of starts to perform.
|
* @param starts Number of starts to perform.
|
||||||
* If {@code starts == 1}, the result will be same as if {@code optimizer}
|
* If {@code starts == 1}, the result will be same as if {@code optimizer}
|
||||||
* is called directly.
|
* is called directly.
|
||||||
* @param generator Random vector generator to use for restarts.
|
* @param generator Generator to use for restarts.
|
||||||
* @throws NullArgumentException if {@code optimizer} or {@code generator}
|
* @throws NullArgumentException if {@code optimizer} or {@code generator}
|
||||||
* is {@code null}.
|
* is {@code null}.
|
||||||
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
* @throws NotStrictlyPositiveException if {@code starts < 1}.
|
||||||
*/
|
*/
|
||||||
public MultiStartMultivariateOptimizer(final MultivariateOptimizer optimizer,
|
public MultiStartMultivariateOptimizer(final MultivariateOptimizer optimizer,
|
||||||
final int starts,
|
final int starts,
|
||||||
final RandomVectorGenerator generator)
|
final Supplier<double[]> generator) {
|
||||||
throws NullArgumentException,
|
|
||||||
NotStrictlyPositiveException {
|
|
||||||
super(optimizer, starts, generator);
|
super(optimizer, starts, generator);
|
||||||
this.optimizer = optimizer;
|
this.optimizer = optimizer;
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.commons.math4.legacy.optim.nonlinear.scalar;
|
package org.apache.commons.math4.legacy.optim.nonlinear.scalar;
|
||||||
|
|
||||||
|
import java.util.function.Supplier;
|
||||||
import org.apache.commons.geometry.euclidean.twod.Vector2D;
|
import org.apache.commons.geometry.euclidean.twod.Vector2D;
|
||||||
import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
|
import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
|
||||||
import org.apache.commons.math4.legacy.optim.InitialGuess;
|
import org.apache.commons.math4.legacy.optim.InitialGuess;
|
||||||
|
@ -30,7 +31,6 @@ import org.apache.commons.rng.UniformRandomProvider;
|
||||||
import org.apache.commons.rng.simple.RandomSource;
|
import org.apache.commons.rng.simple.RandomSource;
|
||||||
import org.apache.commons.rng.sampling.distribution.GaussianSampler;
|
import org.apache.commons.rng.sampling.distribution.GaussianSampler;
|
||||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
|
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
|
||||||
import org.apache.commons.math4.legacy.random.RandomVectorGenerator;
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
@ -49,9 +49,9 @@ public class MultiStartMultivariateOptimizerTest {
|
||||||
GradientMultivariateOptimizer underlying
|
GradientMultivariateOptimizer underlying
|
||||||
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
|
||||||
new SimpleValueChecker(1e-10, 1e-10));
|
new SimpleValueChecker(1e-10, 1e-10));
|
||||||
final RandomVectorGenerator generator = gaussianRandom(new double[] { 50, 50 },
|
final Supplier<double[]> generator = gaussianRandom(new double[] { 50, 50 },
|
||||||
new double[] { 10, 10 },
|
new double[] { 10, 10 },
|
||||||
RandomSource.create(RandomSource.MT_64));
|
RandomSource.create(RandomSource.MT_64));
|
||||||
int nbStarts = 10;
|
int nbStarts = 10;
|
||||||
MultiStartMultivariateOptimizer optimizer
|
MultiStartMultivariateOptimizer optimizer
|
||||||
= new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
|
= new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
|
||||||
|
@ -88,9 +88,9 @@ public class MultiStartMultivariateOptimizerTest {
|
||||||
{ 0.9, 1.2 } ,
|
{ 0.9, 1.2 } ,
|
||||||
{ 3.5, -2.3 }
|
{ 3.5, -2.3 }
|
||||||
});
|
});
|
||||||
final RandomVectorGenerator generator = gaussianRandom(new double[] { 0, 0 },
|
final Supplier<double[]> generator = gaussianRandom(new double[] { 0, 0 },
|
||||||
new double[] { 1, 1 },
|
new double[] { 1, 1 },
|
||||||
RandomSource.create(RandomSource.MT_64));
|
RandomSource.create(RandomSource.MT_64));
|
||||||
int nbStarts = 10;
|
int nbStarts = 10;
|
||||||
MultiStartMultivariateOptimizer optimizer
|
MultiStartMultivariateOptimizer optimizer
|
||||||
= new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
|
= new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
|
||||||
|
@ -136,18 +136,18 @@ public class MultiStartMultivariateOptimizerTest {
|
||||||
* @return a random array generator where each element is a Gaussian
|
* @return a random array generator where each element is a Gaussian
|
||||||
* sampling with the given mean and standard deviation.
|
* sampling with the given mean and standard deviation.
|
||||||
*/
|
*/
|
||||||
private RandomVectorGenerator gaussianRandom(final double[] mean,
|
private Supplier<double[]> gaussianRandom(final double[] mean,
|
||||||
final double[] stdev,
|
final double[] stdev,
|
||||||
final UniformRandomProvider rng) {
|
final UniformRandomProvider rng) {
|
||||||
final ZigguratNormalizedGaussianSampler normalized = new ZigguratNormalizedGaussianSampler(rng);
|
final ZigguratNormalizedGaussianSampler normalized = new ZigguratNormalizedGaussianSampler(rng);
|
||||||
final GaussianSampler[] samplers = new GaussianSampler[mean.length];
|
final GaussianSampler[] samplers = new GaussianSampler[mean.length];
|
||||||
for (int i = 0; i < mean.length; i++) {
|
for (int i = 0; i < mean.length; i++) {
|
||||||
samplers[i] = new GaussianSampler(normalized, mean[i], stdev[i]);
|
samplers[i] = new GaussianSampler(normalized, mean[i], stdev[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new RandomVectorGenerator() {
|
return new Supplier<double[]>() {
|
||||||
@Override
|
@Override
|
||||||
public double[] nextVector() {
|
public double[] get() {
|
||||||
final double[] s = new double[mean.length];
|
final double[] s = new double[mean.length];
|
||||||
for (int i = 0; i < mean.length; i++) {
|
for (int i = 0; i < mean.length; i++) {
|
||||||
s[i] = samplers[i].sample();
|
s[i] = samplers[i].sample();
|
||||||
|
|
Loading…
Reference in New Issue