From f9a7c568bea5e17d1d1197ef330d185be21785a9 Mon Sep 17 00:00:00 2001 From: Gilles Sadowski Date: Sun, 30 May 2021 15:02:13 +0200 Subject: [PATCH] MATH-1596: Remove dependency on "RandomVectorGenerator". --- .../BaseMultiStartMultivariateOptimizer.java | 14 +++++------ .../MultiStartMultivariateOptimizer.java | 8 +++---- .../MultiStartMultivariateOptimizerTest.java | 24 +++++++++---------- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/BaseMultiStartMultivariateOptimizer.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/BaseMultiStartMultivariateOptimizer.java index 8530341fb..f4e57f3e6 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/BaseMultiStartMultivariateOptimizer.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/BaseMultiStartMultivariateOptimizer.java @@ -16,10 +16,10 @@ */ package org.apache.commons.math4.legacy.optim; +import java.util.function.Supplier; import org.apache.commons.math4.legacy.exception.MathIllegalStateException; import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; import org.apache.commons.math4.legacy.exception.TooManyEvaluationsException; -import org.apache.commons.math4.legacy.random.RandomVectorGenerator; /** * Base class multi-start optimizer for a multivariate function. @@ -41,9 +41,9 @@ public abstract class BaseMultiStartMultivariateOptimizer /** Number of evaluations already performed for all starts. */ private int totalEvaluations; /** Number of starts to go. */ - private int starts; - /** Random generator for multi-start. */ - private RandomVectorGenerator generator; + private final int starts; + /** Generator of start points ("multi-start"). */ + private final Supplier generator; /** Optimization data. */ private OptimizationData[] optimData; /** @@ -72,12 +72,12 @@ public abstract class BaseMultiStartMultivariateOptimizer * @param starts Number of starts to perform. If {@code starts == 1}, * the {@link #optimize(OptimizationData[]) optimize} will return the * same solution as the given {@code optimizer} would return. - * @param generator Random vector generator to use for restarts. + * @param generator Generator to use for restarts. * @throws NotStrictlyPositiveException if {@code starts < 1}. */ public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer optimizer, final int starts, - final RandomVectorGenerator generator) { + final Supplier generator) { super(optimizer.getConvergenceChecker()); if (starts < 1) { @@ -185,7 +185,7 @@ public abstract class BaseMultiStartMultivariateOptimizer if (attempts++ >= getMaxEvaluations()) { throw new TooManyEvaluationsException(getMaxEvaluations()); } - s = generator.nextVector(); + s = generator.get(); for (int k = 0; s != null && k < s.length; ++k) { if ((min != null && s[k] < min[k]) || (max != null && s[k] > max[k])) { // reject the vector diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizer.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizer.java index 841dc9bb6..260185b1c 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizer.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizer.java @@ -20,12 +20,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; +import java.util.function.Supplier; import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; import org.apache.commons.math4.legacy.exception.NullArgumentException; import org.apache.commons.math4.legacy.optim.BaseMultiStartMultivariateOptimizer; import org.apache.commons.math4.legacy.optim.PointValuePair; -import org.apache.commons.math4.legacy.random.RandomVectorGenerator; /** * Multi-start optimizer. @@ -50,16 +50,14 @@ public class MultiStartMultivariateOptimizer * @param starts Number of starts to perform. * If {@code starts == 1}, the result will be same as if {@code optimizer} * is called directly. - * @param generator Random vector generator to use for restarts. + * @param generator Generator to use for restarts. * @throws NullArgumentException if {@code optimizer} or {@code generator} * is {@code null}. * @throws NotStrictlyPositiveException if {@code starts < 1}. */ public MultiStartMultivariateOptimizer(final MultivariateOptimizer optimizer, final int starts, - final RandomVectorGenerator generator) - throws NullArgumentException, - NotStrictlyPositiveException { + final Supplier generator) { super(optimizer, starts, generator); this.optimizer = optimizer; } diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizerTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizerTest.java index 96b228e77..f7fc08c29 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizerTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/optim/nonlinear/scalar/MultiStartMultivariateOptimizerTest.java @@ -16,6 +16,7 @@ */ package org.apache.commons.math4.legacy.optim.nonlinear.scalar; +import java.util.function.Supplier; import org.apache.commons.geometry.euclidean.twod.Vector2D; import org.apache.commons.math4.legacy.analysis.MultivariateFunction; import org.apache.commons.math4.legacy.optim.InitialGuess; @@ -30,7 +31,6 @@ import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.simple.RandomSource; import org.apache.commons.rng.sampling.distribution.GaussianSampler; import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler; -import org.apache.commons.math4.legacy.random.RandomVectorGenerator; import org.junit.Assert; import org.junit.Test; @@ -49,9 +49,9 @@ public class MultiStartMultivariateOptimizerTest { GradientMultivariateOptimizer underlying = new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE, new SimpleValueChecker(1e-10, 1e-10)); - final RandomVectorGenerator generator = gaussianRandom(new double[] { 50, 50 }, - new double[] { 10, 10 }, - RandomSource.create(RandomSource.MT_64)); + final Supplier generator = gaussianRandom(new double[] { 50, 50 }, + new double[] { 10, 10 }, + RandomSource.create(RandomSource.MT_64)); int nbStarts = 10; MultiStartMultivariateOptimizer optimizer = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator); @@ -88,9 +88,9 @@ public class MultiStartMultivariateOptimizerTest { { 0.9, 1.2 } , { 3.5, -2.3 } }); - final RandomVectorGenerator generator = gaussianRandom(new double[] { 0, 0 }, - new double[] { 1, 1 }, - RandomSource.create(RandomSource.MT_64)); + final Supplier generator = gaussianRandom(new double[] { 0, 0 }, + new double[] { 1, 1 }, + RandomSource.create(RandomSource.MT_64)); int nbStarts = 10; MultiStartMultivariateOptimizer optimizer = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator); @@ -136,18 +136,18 @@ public class MultiStartMultivariateOptimizerTest { * @return a random array generator where each element is a Gaussian * sampling with the given mean and standard deviation. */ - private RandomVectorGenerator gaussianRandom(final double[] mean, - final double[] stdev, - final UniformRandomProvider rng) { + private Supplier gaussianRandom(final double[] mean, + final double[] stdev, + final UniformRandomProvider rng) { final ZigguratNormalizedGaussianSampler normalized = new ZigguratNormalizedGaussianSampler(rng); final GaussianSampler[] samplers = new GaussianSampler[mean.length]; for (int i = 0; i < mean.length; i++) { samplers[i] = new GaussianSampler(normalized, mean[i], stdev[i]); } - return new RandomVectorGenerator() { + return new Supplier() { @Override - public double[] nextVector() { + public double[] get() { final double[] s = new double[mean.length]; for (int i = 0; i < mean.length; i++) { s[i] = samplers[i].sample();