MATH-1596: Remove dependency on "RandomVectorGenerator".

This commit is contained in:
Gilles Sadowski 2021-05-30 15:02:13 +02:00
parent c93520a02f
commit f9a7c568be
3 changed files with 22 additions and 24 deletions

View File

@ -16,10 +16,10 @@
*/ */
package org.apache.commons.math4.legacy.optim; package org.apache.commons.math4.legacy.optim;
import java.util.function.Supplier;
import org.apache.commons.math4.legacy.exception.MathIllegalStateException; import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException; import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException;
import org.apache.commons.math4.legacy.random.RandomVectorGenerator;
/** /**
* Base class multi-start optimizer for a multivariate function. * Base class multi-start optimizer for a multivariate function.
@ -41,9 +41,9 @@ public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
/** Number of evaluations already performed for all starts. */ /** Number of evaluations already performed for all starts. */
private int totalEvaluations; private int totalEvaluations;
/** Number of starts to go. */ /** Number of starts to go. */
private int starts; private final int starts;
/** Random generator for multi-start. */ /** Generator of start points ("multi-start"). */
private RandomVectorGenerator generator; private final Supplier<double[]> generator;
/** Optimization data. */ /** Optimization data. */
private OptimizationData[] optimData; private OptimizationData[] optimData;
/** /**
@ -72,12 +72,12 @@ public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
* @param starts Number of starts to perform. If {@code starts == 1}, * @param starts Number of starts to perform. If {@code starts == 1},
* the {@link #optimize(OptimizationData[]) optimize} will return the * the {@link #optimize(OptimizationData[]) optimize} will return the
* same solution as the given {@code optimizer} would return. * same solution as the given {@code optimizer} would return.
* @param generator Random vector generator to use for restarts. * @param generator Generator to use for restarts.
* @throws NotStrictlyPositiveException if {@code starts < 1}. * @throws NotStrictlyPositiveException if {@code starts < 1}.
*/ */
public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer, public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer,
final int starts, final int starts,
final RandomVectorGenerator generator) { final Supplier<double[]> generator) {
super(optimizer.getConvergenceChecker()); super(optimizer.getConvergenceChecker());
if (starts < 1) { if (starts < 1) {
@ -185,7 +185,7 @@ public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
if (attempts++ >= getMaxEvaluations()) { if (attempts++ >= getMaxEvaluations()) {
throw new TooManyEvaluationsException(getMaxEvaluations()); throw new TooManyEvaluationsException(getMaxEvaluations());
} }
s = generator.nextVector(); s = generator.get();
for (int k = 0; s != null && k < s.length; ++k) { for (int k = 0; s != null && k < s.length; ++k) {
if ((min != null && s[k] < min[k]) || (max != null && s[k] > max[k])) { if ((min != null && s[k] < min[k]) || (max != null && s[k] > max[k])) {
// reject the vector // reject the vector

View File

@ -20,12 +20,12 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.function.Supplier;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NullArgumentException; import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.optim.BaseMultiStartMultivariateOptimizer; import org.apache.commons.math4.legacy.optim.BaseMultiStartMultivariateOptimizer;
import org.apache.commons.math4.legacy.optim.PointValuePair; import org.apache.commons.math4.legacy.optim.PointValuePair;
import org.apache.commons.math4.legacy.random.RandomVectorGenerator;
/** /**
* Multi-start optimizer. * Multi-start optimizer.
@ -50,16 +50,14 @@ public class MultiStartMultivariateOptimizer
* @param starts Number of starts to perform. * @param starts Number of starts to perform.
* If {@code starts == 1}, the result will be same as if {@code optimizer} * If {@code starts == 1}, the result will be same as if {@code optimizer}
* is called directly. * is called directly.
* @param generator Random vector generator to use for restarts. * @param generator Generator to use for restarts.
* @throws NullArgumentException if {@code optimizer} or {@code generator} * @throws NullArgumentException if {@code optimizer} or {@code generator}
* is {@code null}. * is {@code null}.
* @throws NotStrictlyPositiveException if {@code starts < 1}. * @throws NotStrictlyPositiveException if {@code starts < 1}.
*/ */
public MultiStartMultivariateOptimizer(final MultivariateOptimizer optimizer, public MultiStartMultivariateOptimizer(final MultivariateOptimizer optimizer,
final int starts, final int starts,
final RandomVectorGenerator generator) final Supplier<double[]> generator) {
throws NullArgumentException,
NotStrictlyPositiveException {
super(optimizer, starts, generator); super(optimizer, starts, generator);
this.optimizer = optimizer; this.optimizer = optimizer;
} }

View File

@ -16,6 +16,7 @@
*/ */
package org.apache.commons.math4.legacy.optim.nonlinear.scalar; package org.apache.commons.math4.legacy.optim.nonlinear.scalar;
import java.util.function.Supplier;
import org.apache.commons.geometry.euclidean.twod.Vector2D; import org.apache.commons.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math4.legacy.analysis.MultivariateFunction; import org.apache.commons.math4.legacy.analysis.MultivariateFunction;
import org.apache.commons.math4.legacy.optim.InitialGuess; import org.apache.commons.math4.legacy.optim.InitialGuess;
@ -30,7 +31,6 @@ import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource; import org.apache.commons.rng.simple.RandomSource;
import org.apache.commons.rng.sampling.distribution.GaussianSampler; import org.apache.commons.rng.sampling.distribution.GaussianSampler;
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler; import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
import org.apache.commons.math4.legacy.random.RandomVectorGenerator;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -49,9 +49,9 @@ public class MultiStartMultivariateOptimizerTest {
GradientMultivariateOptimizer underlying GradientMultivariateOptimizer underlying
= new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE, = new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE,
new SimpleValueChecker(1e-10, 1e-10)); new SimpleValueChecker(1e-10, 1e-10));
final RandomVectorGenerator generator = gaussianRandom(new double[] { 50, 50 }, final Supplier<double[]> generator = gaussianRandom(new double[] { 50, 50 },
new double[] { 10, 10 }, new double[] { 10, 10 },
RandomSource.create(RandomSource.MT_64)); RandomSource.create(RandomSource.MT_64));
int nbStarts = 10; int nbStarts = 10;
MultiStartMultivariateOptimizer optimizer MultiStartMultivariateOptimizer optimizer
= new MultiStartMultivariateOptimizer(underlying, nbStarts, generator); = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
@ -88,9 +88,9 @@ public class MultiStartMultivariateOptimizerTest {
{ 0.9, 1.2 } , { 0.9, 1.2 } ,
{ 3.5, -2.3 } { 3.5, -2.3 }
}); });
final RandomVectorGenerator generator = gaussianRandom(new double[] { 0, 0 }, final Supplier<double[]> generator = gaussianRandom(new double[] { 0, 0 },
new double[] { 1, 1 }, new double[] { 1, 1 },
RandomSource.create(RandomSource.MT_64)); RandomSource.create(RandomSource.MT_64));
int nbStarts = 10; int nbStarts = 10;
MultiStartMultivariateOptimizer optimizer MultiStartMultivariateOptimizer optimizer
= new MultiStartMultivariateOptimizer(underlying, nbStarts, generator); = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);
@ -136,18 +136,18 @@ public class MultiStartMultivariateOptimizerTest {
* @return a random array generator where each element is a Gaussian * @return a random array generator where each element is a Gaussian
* sampling with the given mean and standard deviation. * sampling with the given mean and standard deviation.
*/ */
private RandomVectorGenerator gaussianRandom(final double[] mean, private Supplier<double[]> gaussianRandom(final double[] mean,
final double[] stdev, final double[] stdev,
final UniformRandomProvider rng) { final UniformRandomProvider rng) {
final ZigguratNormalizedGaussianSampler normalized = new ZigguratNormalizedGaussianSampler(rng); final ZigguratNormalizedGaussianSampler normalized = new ZigguratNormalizedGaussianSampler(rng);
final GaussianSampler[] samplers = new GaussianSampler[mean.length]; final GaussianSampler[] samplers = new GaussianSampler[mean.length];
for (int i = 0; i < mean.length; i++) { for (int i = 0; i < mean.length; i++) {
samplers[i] = new GaussianSampler(normalized, mean[i], stdev[i]); samplers[i] = new GaussianSampler(normalized, mean[i], stdev[i]);
} }
return new RandomVectorGenerator() { return new Supplier<double[]>() {
@Override @Override
public double[] nextVector() { public double[] get() {
final double[] s = new double[mean.length]; final double[] s = new double[mean.length];
for (int i = 0; i < mean.length; i++) { for (int i = 0; i < mean.length; i++) {
s[i] = samplers[i].sample(); s[i] = samplers[i].sample();