diff --git a/src/main/java/org/apache/commons/math/optimization/direct/CMAESOptimizer.java b/src/main/java/org/apache/commons/math/optimization/direct/CMAESOptimizer.java index 08651635d..3f2a764cf 100644 --- a/src/main/java/org/apache/commons/math/optimization/direct/CMAESOptimizer.java +++ b/src/main/java/org/apache/commons/math/optimization/direct/CMAESOptimizer.java @@ -37,6 +37,7 @@ import org.apache.commons.math.optimization.MultivariateRealOptimizer; import org.apache.commons.math.optimization.RealPointValuePair; import org.apache.commons.math.random.MersenneTwister; import org.apache.commons.math.random.RandomGenerator; +import org.apache.commons.math.util.MathUtils; /** * CMA-ES algorithm. This code is translated and adapted from the Matlab version @@ -386,9 +387,9 @@ public class CMAESOptimizer extends int[] arindex = sortedIndices(fitness); // Calculate new xmean, this is selection and recombination RealMatrix xold = xmean; // for speed up of Eq. (2) and (3) - RealMatrix bestArx = selectColumns(arx,Arrays.copyOf(arindex, mu)); + RealMatrix bestArx = selectColumns(arx, MathUtils.copyOf(arindex, mu)); xmean = bestArx.multiply(weights); - RealMatrix bestArz = selectColumns(arz,Arrays.copyOf(arindex, mu)); + RealMatrix bestArz = selectColumns(arz, MathUtils.copyOf(arindex, mu)); RealMatrix zmean = bestArz.multiply(weights); boolean hsig = updateEvolutionPaths(zmean, xold); if (diagonalOnly <= 0) @@ -678,8 +679,8 @@ public class CMAESOptimizer extends // loss, // prepare vectors, compute negative updating matrix Cneg int[] arReverseIndex = reverse(arindex); - RealMatrix arzneg = selectColumns(arz, - Arrays.copyOf(arReverseIndex, mu)); + RealMatrix arzneg + = selectColumns(arz, MathUtils.copyOf(arReverseIndex, mu)); RealMatrix arnorms = sqrt(sumRows(square(arzneg))); int[] idxnorms = sortedIndices(arnorms.getRow(0)); RealMatrix arnormsSorted = selectColumns(arnorms, idxnorms); diff --git a/src/main/java/org/apache/commons/math/util/MathUtils.java b/src/main/java/org/apache/commons/math/util/MathUtils.java index 59e4b1c79..0a56b9945 100644 --- a/src/main/java/org/apache/commons/math/util/MathUtils.java +++ b/src/main/java/org/apache/commons/math/util/MathUtils.java @@ -2220,10 +2220,7 @@ public final class MathUtils { * @return the copied array. */ public static int[] copyOf(int[] source) { - final int len = source.length; - final int[] output = new int[len]; - System.arraycopy(source, 0, output, 0, len); - return output; + return copyOf(source, source.length); } /** @@ -2233,9 +2230,36 @@ public final class MathUtils { * @return the copied array. */ public static double[] copyOf(double[] source) { - final int len = source.length; + return copyOf(source, source.length); + } + + /** + * Creates a copy of the {@code source} array. + * + * @param source Array to be copied. + * @param len Number of entries to copy. If smaller then the source + * length, the copy will be truncated, if larger it will padded with + * zeroes. + * @return the copied array. + */ + public static int[] copyOf(int[] source, int len) { + final int[] output = new int[len]; + System.arraycopy(source, 0, output, 0, FastMath.min(len, source.length)); + return output; + } + + /** + * Creates a copy of the {@code source} array. + * + * @param source Array to be copied. + * @param len Number of entries to copy. If smaller then the source + * length, the copy will be truncated, if larger it will padded with + * zeroes. + * @return the copied array. + */ + public static double[] copyOf(double[] source, int len) { final double[] output = new double[len]; - System.arraycopy(source, 0, output, 0, len); + System.arraycopy(source, 0, output, 0, FastMath.min(len, source.length)); return output; } } diff --git a/src/test/java/org/apache/commons/math/util/MathUtilsTest.java b/src/test/java/org/apache/commons/math/util/MathUtilsTest.java index c5787770e..5b49797e7 100644 --- a/src/test/java/org/apache/commons/math/util/MathUtilsTest.java +++ b/src/test/java/org/apache/commons/math/util/MathUtilsTest.java @@ -1557,6 +1557,35 @@ public final class MathUtilsTest extends TestCase { } } + public void testCopyOfInt2() { + final int[] source = { Integer.MIN_VALUE, + -1, 0, 1, 3, 113, 4769, + Integer.MAX_VALUE }; + final int offset = 3; + final int[] dest = MathUtils.copyOf(source, source.length - offset); + + assertEquals(dest.length, source.length - offset); + for (int i = 0; i < source.length - offset; i++) { + assertEquals(source[i], dest[i]); + } + } + + public void testCopyOfInt3() { + final int[] source = { Integer.MIN_VALUE, + -1, 0, 1, 3, 113, 4769, + Integer.MAX_VALUE }; + final int offset = 3; + final int[] dest = MathUtils.copyOf(source, source.length + offset); + + assertEquals(dest.length, source.length + offset); + for (int i = 0; i < source.length; i++) { + assertEquals(source[i], dest[i]); + } + for (int i = source.length; i < source.length + offset; i++) { + assertEquals(0, dest[i], 0); + } + } + public void testCopyOfDouble() { final double[] source = { Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, @@ -1573,4 +1602,43 @@ public final class MathUtilsTest extends TestCase { assertEquals(source[i], dest[i], 0); } } + + public void testCopyOfDouble2() { + final double[] source = { Double.NEGATIVE_INFINITY, + -Double.MAX_VALUE, + -1, 0, + Double.MIN_VALUE, + Math.ulp(1d), + 1, 3, 113, 4769, + Double.MAX_VALUE, + Double.POSITIVE_INFINITY }; + final int offset = 3; + final double[] dest = MathUtils.copyOf(source, source.length - offset); + + assertEquals(dest.length, source.length - offset); + for (int i = 0; i < source.length - offset; i++) { + assertEquals(source[i], dest[i], 0); + } + } + + public void testCopyOfDouble3() { + final double[] source = { Double.NEGATIVE_INFINITY, + -Double.MAX_VALUE, + -1, 0, + Double.MIN_VALUE, + Math.ulp(1d), + 1, 3, 113, 4769, + Double.MAX_VALUE, + Double.POSITIVE_INFINITY }; + final int offset = 3; + final double[] dest = MathUtils.copyOf(source, source.length + offset); + + assertEquals(dest.length, source.length + offset); + for (int i = 0; i < source.length; i++) { + assertEquals(source[i], dest[i], 0); + } + for (int i = source.length; i < source.length + offset; i++) { + assertEquals(0, dest[i], 0); + } + } }