diff --git a/pom.xml b/pom.xml index 4a8456eae..dcbbb3269 100644 --- a/pom.xml +++ b/pom.xml @@ -147,6 +147,9 @@ Rémi Arntzen + + Matt Adereth + Jared Becksfort diff --git a/src/changes/changes.xml b/src/changes/changes.xml index b14ec3620..536443cde 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces! + + Added Kendall's tau correlation (KendallsCorrelation). + "EigenDecomposition" may have failed to compute the decomposition for certain non-symmetric matrices. Port of the respective bugfix in Jama-1.0.3. diff --git a/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java b/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java new file mode 100644 index 000000000..a828e785b --- /dev/null +++ b/src/main/java/org/apache/commons/math3/stat/correlation/KendallsCorrelation.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.math3.stat.correlation; + +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.linear.BlockRealMatrix; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Pair; + +import java.util.Arrays; +import java.util.Comparator; + +/** + * Implementation of Kendall's Tau-b rank correlation. + *

+ * A pair of observations (x1, y1) and + * (x2, y2) are considered concordant if + * x1 < x2 and y1 < y2 + * or x2 < x1 and y2 < y1. + * The pair is discordant if x1 < x2 and + * y2 < y1 or x2 < x1 and + * y1 < y2. If either x1 = x2 + * or y1 = y2, the pair is neither concordant nor + * discordant. + *

+ * Kendall's Tau-b is defined as: + *

+ * taub = (nc - nd) / sqrt((n0 - n1) * (n0 - n2))
+ * 
+ *

+ * where: + *

+ *

+ * This implementation uses the O(n log n) algorithm described in + * William R. Knight's 1966 paper "A Computer Method for Calculating + * Kendall's Tau with Ungrouped Data" in the Journal of the American + * Statistical Association. + * + * @see + * Kendall tau rank correlation coefficient (Wikipedia) + * @see A Computer + * Method for Calculating Kendall's Tau with Ungrouped Data + * + * @version $Id$ + * @since 3.3 + */ +public class KendallsCorrelation { + + /** correlation matrix */ + private final RealMatrix correlationMatrix; + + /** + * Create a KendallsCorrelation instance without data. + */ + public KendallsCorrelation() { + correlationMatrix = null; + } + + /** + * Create a KendallsCorrelation from a rectangular array + * whose columns represent values of variables to be correlated. + * + * @param data rectangular array with columns representing variables + * @throws IllegalArgumentException if the input data array is not + * rectangular with at least two rows and two columns. + */ + public KendallsCorrelation(double[][] data) { + this(MatrixUtils.createRealMatrix(data)); + } + + /** + * Create a KendallsCorrelation from a RealMatrix whose columns + * represent variables to be correlated. + * + * @param matrix matrix with columns representing variables to correlate + */ + public KendallsCorrelation(RealMatrix matrix) { + correlationMatrix = computeCorrelationMatrix(matrix); + } + + /** + * Returns the correlation matrix. + * + * @return correlation matrix + */ + public RealMatrix getCorrelationMatrix() { + return correlationMatrix; + } + + /** + * Computes the Kendall's Tau rank correlation matrix for the columns of + * the input matrix. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) { + int nVars = matrix.getColumnDimension(); + RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); + for (int i = 0; i < nVars; i++) { + for (int j = 0; j < i; j++) { + double corr = correlation(matrix.getColumn(i), matrix.getColumn(j)); + outMatrix.setEntry(i, j, corr); + outMatrix.setEntry(j, i, corr); + } + outMatrix.setEntry(i, i, 1d); + } + return outMatrix; + } + + /** + * Computes the Kendall's Tau rank correlation matrix for the columns of + * the input rectangular array. The columns of the array represent values + * of variables to be correlated. + * + * @param matrix matrix with columns representing variables to correlate + * @return correlation matrix + */ + public RealMatrix computeCorrelationMatrix(final double[][] matrix) { + return computeCorrelationMatrix(new BlockRealMatrix(matrix)); + } + + /** + * Computes the Kendall's Tau rank correlation coefficient between the two arrays. + * + * @param xArray first data array + * @param yArray second data array + * @return Returns Kendall's Tau rank correlation coefficient for the two arrays + * @throws DimensionMismatchException if the arrays lengths do not match + */ + public double correlation(final double[] xArray, final double[] yArray) + throws DimensionMismatchException { + + if (xArray.length != yArray.length) { + throw new DimensionMismatchException(xArray.length, yArray.length); + } + + final int n = xArray.length; + final int numPairs = n * (n - 1) / 2; + + @SuppressWarnings("unchecked") + Pair[] pairs = new Pair[n]; + for (int i = 0; i < n; i++) { + pairs[i] = new Pair(xArray[i], yArray[i]); + } + + Arrays.sort(pairs, new Comparator>() { + @Override + public int compare(Pair pair1, Pair pair2) { + int compareFirst = pair1.getFirst().compareTo(pair2.getFirst()); + return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond()); + } + }); + + int tiedXPairs = 0; + int tiedXYPairs = 0; + int consecutiveXTies = 1; + int consecutiveXYTies = 1; + Pair prev = pairs[0]; + for (int i = 1; i < n; i++) { + final Pair curr = pairs[i]; + if (curr.getFirst().equals(prev.getFirst())) { + consecutiveXTies++; + if (curr.getSecond().equals(prev.getSecond())) { + consecutiveXYTies++; + } else { + tiedXYPairs += consecutiveXYTies * (consecutiveXYTies - 1) / 2; + consecutiveXYTies = 1; + } + } else { + tiedXPairs += consecutiveXTies * (consecutiveXTies - 1) / 2; + consecutiveXTies = 1; + tiedXYPairs += consecutiveXYTies * (consecutiveXYTies - 1) / 2; + consecutiveXYTies = 1; + } + prev = curr; + } + tiedXPairs += consecutiveXTies * (consecutiveXTies - 1) / 2; + tiedXYPairs += consecutiveXYTies * (consecutiveXYTies - 1) / 2; + + int swaps = 0; + @SuppressWarnings("unchecked") + Pair[] pairsDestination = new Pair[n]; + for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) { + for (int offset = 0; offset < n; offset += 2 * segmentSize) { + int i = offset; + final int iEnd = FastMath.min(i + segmentSize, n); + int j = iEnd; + final int jEnd = FastMath.min(j + segmentSize, n); + + int copyLocation = offset; + while (i < iEnd || j < jEnd) { + if (i < iEnd) { + if (j < jEnd) { + if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) { + pairsDestination[copyLocation] = pairs[i]; + i++; + } else { + pairsDestination[copyLocation] = pairs[j]; + j++; + swaps += iEnd - i; + } + } else { + pairsDestination[copyLocation] = pairs[i]; + i++; + } + } else { + pairsDestination[copyLocation] = pairs[j]; + j++; + } + copyLocation++; + } + } + final Pair[] pairsTemp = pairs; + pairs = pairsDestination; + pairsDestination = pairsTemp; + } + + int tiedYPairs = 0; + int consecutiveYTies = 1; + prev = pairs[0]; + for (int i = 1; i < n; i++) { + final Pair curr = pairs[i]; + if (curr.getSecond().equals(prev.getSecond())) { + consecutiveYTies++; + } else { + tiedYPairs += consecutiveYTies * (consecutiveYTies - 1) / 2; + consecutiveYTies = 1; + } + prev = curr; + } + tiedYPairs += consecutiveYTies * (consecutiveYTies - 1) / 2; + + int concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps; + return concordantMinusDiscordant / FastMath.sqrt((numPairs - tiedXPairs) * (numPairs - tiedYPairs)); + } +} diff --git a/src/test/java/org/apache/commons/math3/stat/correlation/KendallsCorrelationTest.java b/src/test/java/org/apache/commons/math3/stat/correlation/KendallsCorrelationTest.java new file mode 100644 index 000000000..2fc481b22 --- /dev/null +++ b/src/test/java/org/apache/commons/math3/stat/correlation/KendallsCorrelationTest.java @@ -0,0 +1,235 @@ +package org.apache.commons.math3.stat.correlation; + +import org.apache.commons.math3.TestUtils; +import org.apache.commons.math3.linear.BlockRealMatrix; +import org.apache.commons.math3.linear.RealMatrix; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Test cases for Kendall's Tau rank correlation. + */ +public class KendallsCorrelationTest extends PearsonsCorrelationTest { + + private KendallsCorrelation correlation; + + @Before + public void setUp() { + correlation = new KendallsCorrelation(); + } + + /** + * Test Longley dataset against R. + */ + @Override + @Test + public void testLongly() { + RealMatrix matrix = createRealMatrix(longleyData, 16, 7); + KendallsCorrelation corrInstance = new KendallsCorrelation(matrix); + RealMatrix correlationMatrix = corrInstance.getCorrelationMatrix(); + double[] rData = new double[] { + 1, 0.9166666666666666, 0.9333333333333332, 0.3666666666666666, 0.05, 0.8999999999999999, + 0.8999999999999999, 0.9166666666666666, 1, 0.9833333333333333, 0.45, 0.03333333333333333, + 0.9833333333333333, 0.9833333333333333, 0.9333333333333332, 0.9833333333333333, 1, + 0.4333333333333333, 0.05, 0.9666666666666666, 0.9666666666666666, 0.3666666666666666, + 0.45, 0.4333333333333333, 1, -0.2166666666666666, 0.4666666666666666, 0.4666666666666666, 0.05, + 0.03333333333333333, 0.05, -0.2166666666666666, 1, 0.05, 0.05, 0.8999999999999999, 0.9833333333333333, + 0.9666666666666666, 0.4666666666666666, 0.05, 1, 0.9999999999999999, 0.8999999999999999, + 0.9833333333333333, 0.9666666666666666, 0.4666666666666666, 0.05, 0.9999999999999999, 1 + }; + TestUtils.assertEquals("Kendall's correlation matrix", createRealMatrix(rData, 7, 7), correlationMatrix, 10E-15); + } + + /** + * Test R swiss fertility dataset. + */ + @Test + public void testSwiss() { + RealMatrix matrix = createRealMatrix(swissData, 47, 5); + KendallsCorrelation corrInstance = new KendallsCorrelation(matrix); + RealMatrix correlationMatrix = corrInstance.getCorrelationMatrix(); + double[] rData = new double[] { + 1, 0.1795465254708308, -0.4762437404200669, -0.3306111613580587, 0.2453703703703704, + 0.1795465254708308, 1, -0.4505221560842292, -0.4761645631778491, 0.2054604569820847, + -0.4762437404200669, -0.4505221560842292, 1, 0.528943683925829, -0.3212755391722673, + -0.3306111613580587, -0.4761645631778491, 0.528943683925829, 1, -0.08479652265379604, + 0.2453703703703704, 0.2054604569820847, -0.3212755391722673, -0.08479652265379604, 1 + }; + TestUtils.assertEquals("Kendall's correlation matrix", createRealMatrix(rData, 5, 5), correlationMatrix, 10E-15); + } + + @Test + public void testSimpleOrdered() { + final int length = 10; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + for (int i = 0; i < length; i++) { + xArray[i] = i; + yArray[i] = i; + } + Assert.assertEquals(1.0, correlation.correlation(xArray, yArray), Double.MIN_VALUE); + } + + @Test + public void testSimpleReversed() { + final int length = 10; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + for (int i = 0; i < length; i++) { + xArray[length - i - 1] = i; + yArray[i] = i; + } + Assert.assertEquals(-1.0, correlation.correlation(xArray, yArray), Double.MIN_VALUE); + } + + @Test + public void testSimpleOrderedPowerOf2() { + final int length = 16; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + for (int i = 0; i < length; i++) { + xArray[i] = i; + yArray[i] = i; + } + Assert.assertEquals(1.0, correlation.correlation(xArray, yArray), Double.MIN_VALUE); + } + + @Test + public void testSimpleReversedPowerOf2() { + final int length = 16; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + for (int i = 0; i < length; i++) { + xArray[length - i - 1] = i; + yArray[i] = i; + } + Assert.assertEquals(-1.0, correlation.correlation(xArray, yArray), Double.MIN_VALUE); + } + + @Test + public void testSimpleJumble() { + // A B C D + final double[] xArray = new double[] {1.0, 2.0, 3.0, 4.0}; + final double[] yArray = new double[] {1.0, 3.0, 2.0, 4.0}; + + // 6 pairs: (A,B) (A,C) (A,D) (B,C) (B,D) (C,D) + // (B,C) is discordant, the other 5 are concordant + + Assert.assertEquals((5 - 1) / (double) 6, + correlation.correlation(xArray, yArray), + Double.MIN_VALUE); + } + + @Test + public void testBalancedJumble() { + // A B C D + final double[] xArray = new double[] {1.0, 2.0, 3.0, 4.0}; + final double[] yArray = new double[] {1.0, 4.0, 3.0, 2.0}; + + // 6 pairs: (A,B) (A,C) (A,D) (B,C) (B,D) (C,D) + // (A,B) (A,C), (A,D) are concordant, the other 3 are discordant + + Assert.assertEquals(0.0, + correlation.correlation(xArray, yArray), + Double.MIN_VALUE); + } + + @Test + public void testOrderedTies() { + final int length = 10; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + for (int i = 0; i < length; i++) { + xArray[i] = i / 2; + yArray[i] = i / 2; + } + // 5 pairs of points that are tied in both values. + // 16 + 12 + 8 + 4 = 40 concordant + // (40 - 0) / Math.sqrt((45 - 5) * (45 - 5)) = 1 + Assert.assertEquals(1.0, correlation.correlation(xArray, yArray), Double.MIN_VALUE); + } + + + @Test + public void testAllTiesInBoth() { + final int length = 10; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + Assert.assertEquals(Double.NaN, correlation.correlation(xArray, yArray), 0); + } + + @Test + public void testAllTiesInX() { + final int length = 10; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + for (int i = 0; i < length; i++) { + xArray[i] = i; + } + Assert.assertEquals(Double.NaN, correlation.correlation(xArray, yArray), 0); + } + + @Test + public void testAllTiesInY() { + final int length = 10; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + for (int i = 0; i < length; i++) { + yArray[i] = i; + } + Assert.assertEquals(Double.NaN, correlation.correlation(xArray, yArray), 0); + } + + @Test + public void testSingleElement() { + final int length = 1; + final double[] xArray = new double[length]; + final double[] yArray = new double[length]; + Assert.assertEquals(Double.NaN, correlation.correlation(xArray, yArray), 0); + } + + @Test + public void testTwoElements() { + final double[] xArray = new double[] {2.0, 1.0}; + final double[] yArray = new double[] {1.0, 2.0}; + Assert.assertEquals(-1.0, correlation.correlation(xArray, yArray), Double.MIN_VALUE); + } + + @Test + public void test2dDoubleArray() { + final double[][] input = new double[][] { + new double[] {2.0, 1.0, 2.0}, + new double[] {1.0, 2.0, 1.0}, + new double[] {0.0, 0.0, 0.0} + }; + + final double[][] expected = new double[][] { + new double[] {1.0, 1.0 / 3.0, 1.0}, + new double[] {1.0 / 3.0, 1.0, 1.0 / 3.0}, + new double[] {1.0, 1.0 / 3.0, 1.0}}; + + Assert.assertEquals(correlation.computeCorrelationMatrix(input), + new BlockRealMatrix(expected)); + + } + + @Test + public void testBlockMatrix() { + final double[][] input = new double[][] { + new double[] {2.0, 1.0, 2.0}, + new double[] {1.0, 2.0, 1.0}, + new double[] {0.0, 0.0, 0.0} + }; + + final double[][] expected = new double[][] { + new double[] {1.0, 1.0 / 3.0, 1.0}, + new double[] {1.0 / 3.0, 1.0, 1.0 / 3.0}, + new double[] {1.0, 1.0 / 3.0, 1.0}}; + + Assert.assertEquals( + correlation.computeCorrelationMatrix(new BlockRealMatrix(input)), + new BlockRealMatrix(expected)); + } + +}