Fix FFT Test to use the expected imaginary result for relative error

Update to JUnit 5 Assertions.

Add absolute tolerance check in addition to relative tolerance.

Use Precision for equality checks.

Add relative/abs error in the assertion failure message.
This commit is contained in:
aherbert 2022-06-07 14:24:41 +01:00
parent f37494baca
commit 9293da2a06
2 changed files with 100 additions and 113 deletions

View File

@ -17,14 +17,13 @@
package org.apache.commons.math4.transform; package org.apache.commons.math4.transform;
import java.util.function.DoubleUnaryOperator; import java.util.function.DoubleUnaryOperator;
import java.util.function.Supplier;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.jupiter.api.Assertions;
import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource; import org.apache.commons.rng.simple.RandomSource;
import org.apache.commons.numbers.complex.Complex; import org.apache.commons.numbers.complex.Complex;
import org.apache.commons.numbers.core.Precision;
import org.apache.commons.math3.analysis.UnivariateFunction; import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.function.Sin; import org.apache.commons.math3.analysis.function.Sin;
import org.apache.commons.math3.analysis.function.Sinc; import org.apache.commons.math3.analysis.function.Sinc;
@ -42,23 +41,20 @@ public final class FastFourierTransformerTest {
/** RNG. */ /** RNG. */
private static final UniformRandomProvider RNG = RandomSource.MWC_256.create(); private static final UniformRandomProvider RNG = RandomSource.MWC_256.create();
/** Minimum relative error epsilon. Can be used a small absolute delta epsilon. */
private static final double EPSILON = Math.ulp(1.0);
// Precondition checks. // Precondition checks.
@Test @Test
public void testTransformComplexSizeNotAPowerOfTwo() { public void testTransformComplexSizeNotAPowerOfTwo() {
final int n = 127; final int n = 127;
final Complex[] x = createComplexData(n); final Complex[] x = createComplexData(n);
final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values(); for (FastFourierTransform.Norm norm : FastFourierTransform.Norm.values()) {
for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
final FastFourierTransform fft = new FastFourierTransform(norm[i], type); final FastFourierTransform fft = new FastFourierTransform(norm, type);
try { Assertions.assertThrows(IllegalArgumentException.class, () -> fft.apply(x),
fft.apply(x); () -> norm + ", " + type);
Assert.fail(norm[i] + ", " + type +
": IllegalArgumentException was expected");
} catch (IllegalArgumentException e) {
// Expected behaviour
}
} }
} }
} }
@ -67,17 +63,11 @@ public final class FastFourierTransformerTest {
public void testTransformRealSizeNotAPowerOfTwo() { public void testTransformRealSizeNotAPowerOfTwo() {
final int n = 127; final int n = 127;
final double[] x = createRealData(n); final double[] x = createRealData(n);
final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values(); for (FastFourierTransform.Norm norm : FastFourierTransform.Norm.values()) {
for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
final FastFourierTransform fft = new FastFourierTransform(norm[i], type); final FastFourierTransform fft = new FastFourierTransform(norm, type);
try { Assertions.assertThrows(IllegalArgumentException.class, () -> fft.apply(x),
fft.apply(x); () -> norm + ", " + type);
Assert.fail(norm[i] + ", " + type +
": IllegalArgumentException was expected");
} catch (IllegalArgumentException e) {
// Expected behaviour
}
} }
} }
} }
@ -85,17 +75,11 @@ public final class FastFourierTransformerTest {
@Test @Test
public void testTransformFunctionSizeNotAPowerOfTwo() { public void testTransformFunctionSizeNotAPowerOfTwo() {
final int n = 127; final int n = 127;
final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values(); for (FastFourierTransform.Norm norm : FastFourierTransform.Norm.values()) {
for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
final FastFourierTransform fft = new FastFourierTransform(norm[i], type); final FastFourierTransform fft = new FastFourierTransform(norm, type);
try { Assertions.assertThrows(IllegalArgumentException.class, () -> fft.apply(SIN, 0.0, Math.PI, n),
fft.apply(SIN, 0.0, Math.PI, n); () -> norm + ", " + type);
Assert.fail(norm[i] + ", " + type +
": IllegalArgumentException was expected");
} catch (IllegalArgumentException e) {
// Expected behaviour
}
} }
} }
} }
@ -103,18 +87,11 @@ public final class FastFourierTransformerTest {
@Test @Test
public void testTransformFunctionNotStrictlyPositiveNumberOfSamples() { public void testTransformFunctionNotStrictlyPositiveNumberOfSamples() {
final int n = -128; final int n = -128;
final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values(); for (FastFourierTransform.Norm norm : FastFourierTransform.Norm.values()) {
for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
final FastFourierTransform fft = new FastFourierTransform(norm[i], type); final FastFourierTransform fft = new FastFourierTransform(norm, type);
try { Assertions.assertThrows(IllegalArgumentException.class, () -> fft.apply(SIN, 0.0, Math.PI, n),
fft.apply(SIN, 0.0, Math.PI, n); () -> norm + ", " + type);
fft.apply(SIN, 0.0, Math.PI, n);
Assert.fail(norm[i] + ", " + type +
": IllegalArgumentException was expected");
} catch (IllegalArgumentException e) {
// Expected behaviour
}
} }
} }
} }
@ -122,17 +99,11 @@ public final class FastFourierTransformerTest {
@Test @Test
public void testTransformFunctionInvalidBounds() { public void testTransformFunctionInvalidBounds() {
final int n = 128; final int n = 128;
final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values(); for (FastFourierTransform.Norm norm : FastFourierTransform.Norm.values()) {
for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
final FastFourierTransform fft = new FastFourierTransform(norm[i], type); final FastFourierTransform fft = new FastFourierTransform(norm, type);
try { Assertions.assertThrows(IllegalArgumentException.class, () -> fft.apply(SIN, Math.PI, 0.0, n),
fft.apply(SIN, Math.PI, 0.0, n); () -> norm + ", " + type);
Assert.fail(norm[i] + ", " + type +
": IllegalArgumentException was expected");
} catch (IllegalArgumentException e) {
// Expected behaviour
}
} }
} }
} }
@ -187,6 +158,7 @@ public final class FastFourierTransformerTest {
private static void doTestTransformComplex(final int n, private static void doTestTransformComplex(final int n,
final double tol, final double tol,
final double absTol,
final FastFourierTransform.Norm normalization, final FastFourierTransform.Norm normalization,
boolean inverse) { boolean inverse) {
final FastFourierTransform fft = new FastFourierTransform(normalization, inverse); final FastFourierTransform fft = new FastFourierTransform(normalization, inverse);
@ -210,19 +182,19 @@ public final class FastFourierTransformerTest {
} }
final Complex[] actual = fft.apply(x); final Complex[] actual = fft.apply(x);
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
final String msg; final int index = i;
msg = String.format("%s, %s, %d, %d", normalization, inverse, n, i);
final double re = s * expected[i].getReal(); final double re = s * expected[i].getReal();
Assert.assertEquals(msg, re, actual[i].getReal(), assertEqualsRelativeOrAbsolute(re, actual[i].getReal(), tol, absTol,
tol * Math.abs(re)); () -> String.format("%s, %s, %d, %d", normalization, inverse, n, index));
final double im = s * expected[i].getImaginary(); final double im = s * expected[i].getImaginary();
Assert.assertEquals(msg, im, actual[i].getImaginary(), assertEqualsRelativeOrAbsolute(im, actual[i].getImaginary(), tol, absTol,
tol * Math.abs(re)); () -> String.format("%s, %s, %d, %d", normalization, inverse, n, index));
} }
} }
private static void doTestTransformReal(final int n, private static void doTestTransformReal(final int n,
final double tol, final double tol,
final double absTol,
final FastFourierTransform.Norm normalization, final FastFourierTransform.Norm normalization,
final boolean inverse) { final boolean inverse) {
final FastFourierTransform fft = new FastFourierTransform(normalization, inverse); final FastFourierTransform fft = new FastFourierTransform(normalization, inverse);
@ -250,14 +222,13 @@ public final class FastFourierTransformerTest {
} }
final Complex[] actual = fft.apply(x); final Complex[] actual = fft.apply(x);
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
final String msg; final int index = i;
msg = String.format("%s, %s, %d, %d", normalization, inverse, n, i);
final double re = s * expected[i].getReal(); final double re = s * expected[i].getReal();
Assert.assertEquals(msg, re, actual[i].getReal(), assertEqualsRelativeOrAbsolute(re, actual[i].getReal(), tol, absTol,
tol * Math.abs(re)); () -> String.format("%s, %s, %d, %d", normalization, inverse, n, index));
final double im = s * expected[i].getImaginary(); final double im = s * expected[i].getImaginary();
Assert.assertEquals(msg, im, actual[i].getImaginary(), assertEqualsRelativeOrAbsolute(im, actual[i].getImaginary(), tol, absTol,
tol * Math.abs(re)); () -> String.format("%s, %s, %d, %d", normalization, inverse, n, index));
} }
} }
@ -266,6 +237,7 @@ public final class FastFourierTransformerTest {
final double max, final double max,
int n, int n,
final double tol, final double tol,
final double absTol,
final FastFourierTransform.Norm normalization, final FastFourierTransform.Norm normalization,
final boolean inverse) { final boolean inverse) {
final FastFourierTransform fft = new FastFourierTransform(normalization, inverse); final FastFourierTransform fft = new FastFourierTransform(normalization, inverse);
@ -293,13 +265,27 @@ public final class FastFourierTransformerTest {
} }
final Complex[] actual = fft.apply(f, min, max, n); final Complex[] actual = fft.apply(f, min, max, n);
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
final String msg = String.format("%d, %d", n, i); final int index = i;
final double re = s * expected[i].getReal(); final double re = s * expected[i].getReal();
Assert.assertEquals(msg, re, actual[i].getReal(), assertEqualsRelativeOrAbsolute(re, actual[i].getReal(), tol, absTol,
tol * Math.abs(re)); () -> String.format("%s, %s, %d, %d", normalization, inverse, n, index));
final double im = s * expected[i].getImaginary(); final double im = s * expected[i].getImaginary();
Assert.assertEquals(msg, im, actual[i].getImaginary(), assertEqualsRelativeOrAbsolute(im, actual[i].getImaginary(), tol, absTol,
tol * Math.abs(re)); () -> String.format("%s, %s, %d, %d", normalization, inverse, n, index));
}
}
private static void assertEqualsRelativeOrAbsolute(double expected, double actual,
double relativeTolerance, double absoluteTolerance, Supplier<String> msg) {
if (!(Precision.equals(expected, actual, absoluteTolerance) ||
Precision.equalsWithRelativeTolerance(expected, actual, relativeTolerance))) {
// Custom supplier to provide relative and absolute error
final double absoluteMax = Math.max(Math.abs(expected), Math.abs(actual));
final double abs = Math.abs(expected - actual);
final double rel = abs / absoluteMax;
final Supplier<String> message = () -> String.format("%s: rel=%s, abs=%s", msg.get(), rel, abs);
// Re-use assertEquals to obtain the formatting
Assertions.assertEquals(expected, actual, message);
} }
} }
@ -310,13 +296,13 @@ public final class FastFourierTransformerTest {
final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values(); final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values();
for (int i = 0; i < norm.length; i++) { for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
doTestTransformComplex(2, 1e-15, norm[i], type); doTestTransformComplex(2, 1e-15, EPSILON, norm[i], type);
doTestTransformComplex(4, 1e-14, norm[i], type); doTestTransformComplex(4, 1e-14, EPSILON, norm[i], type);
doTestTransformComplex(8, 1e-13, norm[i], type); doTestTransformComplex(8, 1e-13, EPSILON, norm[i], type);
doTestTransformComplex(16, 1e-13, norm[i], type); doTestTransformComplex(16, 1e-13, EPSILON, norm[i], type);
doTestTransformComplex(32, 1e-12, norm[i], type); doTestTransformComplex(32, 1e-12, EPSILON, norm[i], type);
doTestTransformComplex(64, 1e-11, norm[i], type); doTestTransformComplex(64, 1e-11, EPSILON, norm[i], type);
doTestTransformComplex(128, 1e-11, norm[i], type); doTestTransformComplex(128, 1e-11, EPSILON, norm[i], type);
} }
} }
} }
@ -326,13 +312,13 @@ public final class FastFourierTransformerTest {
final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values(); final FastFourierTransform.Norm[] norm = FastFourierTransform.Norm.values();
for (int i = 0; i < norm.length; i++) { for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
doTestTransformReal(2, 1e-15, norm[i], type); doTestTransformReal(2, 1e-15, EPSILON, norm[i], type);
doTestTransformReal(4, 1e-14, norm[i], type); doTestTransformReal(4, 1e-14, EPSILON, norm[i], type);
doTestTransformReal(8, 1e-13, norm[i], type); doTestTransformReal(8, 1e-13, 2 * EPSILON, norm[i], type);
doTestTransformReal(16, 1e-13, norm[i], type); doTestTransformReal(16, 1e-13, 2 * EPSILON, norm[i], type);
doTestTransformReal(32, 1e-12, norm[i], type); doTestTransformReal(32, 1e-12, 4 * EPSILON, norm[i], type);
doTestTransformReal(64, 1e-12, norm[i], type); doTestTransformReal(64, 1e-12, 4 * EPSILON, norm[i], type);
doTestTransformReal(128, 1e-11, norm[i], type); doTestTransformReal(128, 1e-11, 8 * EPSILON, norm[i], type);
} }
} }
} }
@ -348,13 +334,13 @@ public final class FastFourierTransformerTest {
for (int i = 0; i < norm.length; i++) { for (int i = 0; i < norm.length; i++) {
for (boolean type : new boolean[] {true, false}) { for (boolean type : new boolean[] {true, false}) {
doTestTransformFunction(f, min, max, 2, 1e-15, norm[i], type); doTestTransformFunction(f, min, max, 2, 1e-15, 4 * EPSILON, norm[i], type);
doTestTransformFunction(f, min, max, 4, 1e-14, norm[i], type); doTestTransformFunction(f, min, max, 4, 1e-14, 4 * EPSILON, norm[i], type);
doTestTransformFunction(f, min, max, 8, 1e-14, norm[i], type); doTestTransformFunction(f, min, max, 8, 1e-14, 4 * EPSILON, norm[i], type);
doTestTransformFunction(f, min, max, 16, 1e-13, norm[i], type); doTestTransformFunction(f, min, max, 16, 1e-13, 4 * EPSILON, norm[i], type);
doTestTransformFunction(f, min, max, 32, 1e-13, norm[i], type); doTestTransformFunction(f, min, max, 32, 1e-13, 8 * EPSILON, norm[i], type);
doTestTransformFunction(f, min, max, 64, 1e-12, norm[i], type); doTestTransformFunction(f, min, max, 64, 1e-12, 16 * EPSILON, norm[i], type);
doTestTransformFunction(f, min, max, 128, 1e-11, norm[i], type); doTestTransformFunction(f, min, max, 128, 1e-11, 64 * EPSILON, norm[i], type);
} }
} }
} }
@ -384,15 +370,15 @@ public final class FastFourierTransformerTest {
transformer = new FastFourierTransform(FastFourierTransform.Norm.STD); transformer = new FastFourierTransform(FastFourierTransform.Norm.STD);
result = transformer.apply(x); result = transformer.apply(x);
for (int i = 0; i < result.length; i++) { for (int i = 0; i < result.length; i++) {
Assert.assertEquals(y[i].getReal(), result[i].getReal(), tolerance); Assertions.assertEquals(y[i].getReal(), result[i].getReal(), tolerance);
Assert.assertEquals(y[i].getImaginary(), result[i].getImaginary(), tolerance); Assertions.assertEquals(y[i].getImaginary(), result[i].getImaginary(), tolerance);
} }
transformer = new FastFourierTransform(FastFourierTransform.Norm.STD, true); transformer = new FastFourierTransform(FastFourierTransform.Norm.STD, true);
result = transformer.apply(y); result = transformer.apply(y);
for (int i = 0; i < result.length; i++) { for (int i = 0; i < result.length; i++) {
Assert.assertEquals(x[i], result[i].getReal(), tolerance); Assertions.assertEquals(x[i], result[i].getReal(), tolerance);
Assert.assertEquals(0.0, result[i].getImaginary(), tolerance); Assertions.assertEquals(0.0, result[i].getImaginary(), tolerance);
} }
double[] x2 = {10.4, 21.6, 40.8, 13.6, 23.2, 32.8, 13.6, 19.2}; double[] x2 = {10.4, 21.6, 40.8, 13.6, 23.2, 32.8, 13.6, 19.2};
@ -402,15 +388,15 @@ public final class FastFourierTransformerTest {
transformer = new FastFourierTransform(FastFourierTransform.Norm.UNIT); transformer = new FastFourierTransform(FastFourierTransform.Norm.UNIT);
result = transformer.apply(y2); result = transformer.apply(y2);
for (int i = 0; i < result.length; i++) { for (int i = 0; i < result.length; i++) {
Assert.assertEquals(x2[i], result[i].getReal(), tolerance); Assertions.assertEquals(x2[i], result[i].getReal(), tolerance);
Assert.assertEquals(0.0, result[i].getImaginary(), tolerance); Assertions.assertEquals(0.0, result[i].getImaginary(), tolerance);
} }
transformer = new FastFourierTransform(FastFourierTransform.Norm.UNIT, true); transformer = new FastFourierTransform(FastFourierTransform.Norm.UNIT, true);
result = transformer.apply(x2); result = transformer.apply(x2);
for (int i = 0; i < result.length; i++) { for (int i = 0; i < result.length; i++) {
Assert.assertEquals(y2[i].getReal(), result[i].getReal(), tolerance); Assertions.assertEquals(y2[i].getReal(), result[i].getReal(), tolerance);
Assert.assertEquals(y2[i].getImaginary(), result[i].getImaginary(), tolerance); Assertions.assertEquals(y2[i].getImaginary(), result[i].getImaginary(), tolerance);
} }
} }
@ -428,26 +414,26 @@ public final class FastFourierTransformerTest {
double max = 2 * Math.PI; double max = 2 * Math.PI;
transformer = new FastFourierTransform(FastFourierTransform.Norm.STD); transformer = new FastFourierTransform(FastFourierTransform.Norm.STD);
result = transformer.apply(SIN, min, max, size); result = transformer.apply(SIN, min, max, size);
Assert.assertEquals(0.0, result[1].getReal(), tolerance); Assertions.assertEquals(0.0, result[1].getReal(), tolerance);
Assert.assertEquals(-(size >> 1), result[1].getImaginary(), tolerance); Assertions.assertEquals(-(size >> 1), result[1].getImaginary(), tolerance);
Assert.assertEquals(0.0, result[size - 1].getReal(), tolerance); Assertions.assertEquals(0.0, result[size - 1].getReal(), tolerance);
Assert.assertEquals(size >> 1, result[size - 1].getImaginary(), tolerance); Assertions.assertEquals(size >> 1, result[size - 1].getImaginary(), tolerance);
for (int i = 0; i < size - 1; i += i == 0 ? 2 : 1) { for (int i = 0; i < size - 1; i += i == 0 ? 2 : 1) {
Assert.assertEquals(0.0, result[i].getReal(), tolerance); Assertions.assertEquals(0.0, result[i].getReal(), tolerance);
Assert.assertEquals(0.0, result[i].getImaginary(), tolerance); Assertions.assertEquals(0.0, result[i].getImaginary(), tolerance);
} }
min = -Math.PI; min = -Math.PI;
max = Math.PI; max = Math.PI;
transformer = new FastFourierTransform(FastFourierTransform.Norm.STD, true); transformer = new FastFourierTransform(FastFourierTransform.Norm.STD, true);
result = transformer.apply(SIN, min, max, size); result = transformer.apply(SIN, min, max, size);
Assert.assertEquals(0.0, result[1].getReal(), tolerance); Assertions.assertEquals(0.0, result[1].getReal(), tolerance);
Assert.assertEquals(-0.5, result[1].getImaginary(), tolerance); Assertions.assertEquals(-0.5, result[1].getImaginary(), tolerance);
Assert.assertEquals(0.0, result[size - 1].getReal(), tolerance); Assertions.assertEquals(0.0, result[size - 1].getReal(), tolerance);
Assert.assertEquals(0.5, result[size - 1].getImaginary(), tolerance); Assertions.assertEquals(0.5, result[size - 1].getImaginary(), tolerance);
for (int i = 0; i < size - 1; i += i == 0 ? 2 : 1) { for (int i = 0; i < size - 1; i += i == 0 ? 2 : 1) {
Assert.assertEquals(0.0, result[i].getReal(), tolerance); Assertions.assertEquals(0.0, result[i].getReal(), tolerance);
Assert.assertEquals(0.0, result[i].getImaginary(), tolerance); Assertions.assertEquals(0.0, result[i].getImaginary(), tolerance);
} }
} }
} }

View File

@ -47,4 +47,5 @@
<suppress checks="MethodLength" files=".*/Dfp.*Test.java" /> <suppress checks="MethodLength" files=".*/Dfp.*Test.java" />
<suppress checks="MethodLength" files=".*[/\\]AccurateMathTest\.java" /> <suppress checks="MethodLength" files=".*[/\\]AccurateMathTest\.java" />
<suppress checks="LocalFinalVariableName" files=".*[/\\]AccurateMathTest\.java" /> <suppress checks="LocalFinalVariableName" files=".*[/\\]AccurateMathTest\.java" />
<suppress checks="ParameterNumber" files=".*[/\\]FastFourierTransformerTest\.java" />
</suppressions> </suppressions>