diff --git a/src/main/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructure.java b/src/main/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructure.java index 21ee81140..1e5b6db11 100644 --- a/src/main/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructure.java +++ b/src/main/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructure.java @@ -386,6 +386,22 @@ public class DerivativeStructure implements FieldElement, S FastMath.floor(data[0])); } + /** + * Returns the instance with the sign of the argument. + * A NaN {@code sign} argument is treated as positive. + * + * @param sign the sign for the returned value + * @return the instance with the same sign as the {@code sign} argument + */ + public DerivativeStructure copySign(final double sign){ + long m = Double.doubleToLongBits(data[0]); + long s = Double.doubleToLongBits(sign); + if ((m >= 0 && s >= 0) || (m < 0 && s < 0)) { // Sign is currently OK + return this; + } + return negate(); // flip sign + } + /** {@inheritDoc} */ public DerivativeStructure reciprocal() { final DerivativeStructure result = new DerivativeStructure(compiler); diff --git a/src/test/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructureTest.java b/src/test/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructureTest.java index 152a39e18..7057a2112 100644 --- a/src/test/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructureTest.java +++ b/src/test/java/org/apache/commons/math3/analysis/differentiation/DerivativeStructureTest.java @@ -765,6 +765,21 @@ public class DerivativeStructureTest { } + @Test + public void testCopySign() { + DerivativeStructure minusOne = new DerivativeStructure(1, 1, 0, -1.0); + Assert.assertEquals(+1.0, minusOne.copySign(+1.0).getPartialDerivative(0), 1.0e-15); + Assert.assertEquals(-1.0, minusOne.copySign(+1.0).getPartialDerivative(1), 1.0e-15); + Assert.assertEquals(-1.0, minusOne.copySign(-1.0).getPartialDerivative(0), 1.0e-15); + Assert.assertEquals(+1.0, minusOne.copySign(-1.0).getPartialDerivative(1), 1.0e-15); + Assert.assertEquals(+1.0, minusOne.copySign(+0.0).getPartialDerivative(0), 1.0e-15); + Assert.assertEquals(-1.0, minusOne.copySign(+0.0).getPartialDerivative(1), 1.0e-15); + Assert.assertEquals(-1.0, minusOne.copySign(-0.0).getPartialDerivative(0), 1.0e-15); + Assert.assertEquals(+1.0, minusOne.copySign(-0.0).getPartialDerivative(1), 1.0e-15); + Assert.assertEquals(+1.0, minusOne.copySign(Double.NaN).getPartialDerivative(0), 1.0e-15); + Assert.assertEquals(-1.0, minusOne.copySign(Double.NaN).getPartialDerivative(1), 1.0e-15); + } + @Test public void testField() { for (int maxOrder = 1; maxOrder < 5; ++maxOrder) {