MATH-1596: Removed dependency on "RandomVectorGenerator".

This commit is contained in:
Gilles Sadowski 2021-05-31 03:47:57 +02:00
parent 6f4620f270
commit f24fd14718
3 changed files with 10 additions and 11 deletions

View File

@ -17,13 +17,15 @@
package org.apache.commons.math4.legacy.random; package org.apache.commons.math4.legacy.random;
import java.util.function.Supplier;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException; import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.linear.RealMatrix; import org.apache.commons.math4.legacy.linear.RealMatrix;
import org.apache.commons.math4.legacy.linear.RectangularCholeskyDecomposition; import org.apache.commons.math4.legacy.linear.RectangularCholeskyDecomposition;
/** /**
* A {@link RandomVectorGenerator} that generates vectors with with * Generates vectors with with correlated components.
* correlated components. *
* <p>Random vectors with correlated components are built by combining * <p>Random vectors with correlated components are built by combining
* the uncorrelated components of another random vector in such a way that * the uncorrelated components of another random vector in such a way that
* the resulting correlations are the ones specified by a positive * the resulting correlations are the ones specified by a positive
@ -57,8 +59,7 @@ import org.apache.commons.math4.legacy.linear.RectangularCholeskyDecomposition;
* @since 1.2 * @since 1.2
*/ */
public class CorrelatedRandomVectorGenerator public class CorrelatedRandomVectorGenerator implements Supplier<double[]> {
implements RandomVectorGenerator {
/** Mean vector. */ /** Mean vector. */
private final double[] mean; private final double[] mean;
/** Underlying generator. */ /** Underlying generator. */
@ -162,8 +163,7 @@ public class CorrelatedRandomVectorGenerator
* is created at each call, the caller can do what it wants with it. * is created at each call, the caller can do what it wants with it.
*/ */
@Override @Override
public double[] nextVector() { public double[] get() {
// generate uncorrelated vector // generate uncorrelated vector
for (int i = 0; i < normalized.length; ++i) { for (int i = 0; i < normalized.length; ++i) {
normalized[i] = generator.nextNormalizedDouble(); normalized[i] = generator.nextNormalizedDouble();
@ -181,5 +181,4 @@ public class CorrelatedRandomVectorGenerator
return correlated; return correlated;
} }
} }

View File

@ -89,7 +89,7 @@ public class CorrelatedRandomVectorGeneratorTest {
double[] max = new double[mean.length]; double[] max = new double[mean.length];
Arrays.fill(max, Double.NEGATIVE_INFINITY); Arrays.fill(max, Double.NEGATIVE_INFINITY);
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
double[] generated = sg.nextVector(); double[] generated = sg.get();
for (int j = 0; j < generated.length; ++j) { for (int j = 0; j < generated.length; ++j) {
min[j] = FastMath.min(min[j], generated[j]); min[j] = FastMath.min(min[j], generated[j]);
max[j] = FastMath.max(max[j], generated[j]); max[j] = FastMath.max(max[j], generated[j]);
@ -118,7 +118,7 @@ public class CorrelatedRandomVectorGeneratorTest {
VectorialMean meanStat = new VectorialMean(mean.length); VectorialMean meanStat = new VectorialMean(mean.length);
VectorialCovariance covStat = new VectorialCovariance(mean.length, true); VectorialCovariance covStat = new VectorialCovariance(mean.length, true);
for (int i = 0; i < 5000; ++i) { for (int i = 0; i < 5000; ++i) {
double[] v = generator.nextVector(); double[] v = generator.get();
meanStat.increment(v); meanStat.increment(v);
covStat.increment(v); covStat.increment(v);
} }
@ -181,7 +181,7 @@ public class CorrelatedRandomVectorGeneratorTest {
StorelessCovariance cov = new StorelessCovariance(covMatrix.length); StorelessCovariance cov = new StorelessCovariance(covMatrix.length);
for (int i = 0; i < samples; ++i) { for (int i = 0; i < samples; ++i) {
cov.increment(sampler.nextVector()); cov.increment(sampler.get());
} }
double[][] sampleCov = cov.getData(); double[][] sampleCov = cov.getData();

View File

@ -276,7 +276,7 @@ public class GLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbs
for (int i = 0; i < nModels; i++) { for (int i = 0; i < nModels; i++) {
// Generate y = xb + u with u cov // Generate y = xb + u with u cov
RealVector u = MatrixUtils.createRealVector(gen.nextVector()); RealVector u = MatrixUtils.createRealVector(gen.get());
double[] y = u.add(x.operate(b)).toArray(); double[] y = u.add(x.operate(b)).toArray();
// Estimate OLS parameters // Estimate OLS parameters